Add some auth tests
parent
eb01267643
commit
e6e1984c47
4
auth.go
4
auth.go
|
@ -94,10 +94,10 @@ type extractor []string
|
||||||
func (e extractor) ExtractToken(r *http.Request) (string, error) {
|
func (e extractor) ExtractToken(r *http.Request) (string, error) {
|
||||||
token, _ := request.AuthorizationHeaderExtractor.ExtractToken(r)
|
token, _ := request.AuthorizationHeaderExtractor.ExtractToken(r)
|
||||||
|
|
||||||
// Checks if the token isn't empty and if it contains three dots.
|
// Checks if the token isn't empty and if it contains two dots.
|
||||||
// The former prevents incompatibility with URLs that previously
|
// The former prevents incompatibility with URLs that previously
|
||||||
// used basic auth.
|
// used basic auth.
|
||||||
if token != "" && strings.Count(token, ".") == 3 {
|
if token != "" && strings.Count(token, ".") == 2 {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
package filemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultCredentials = "{\"username\":\"admin\",\"password\":\"admin\"}"
|
||||||
|
|
||||||
|
var authHandlerTests = []struct {
|
||||||
|
Data string
|
||||||
|
Expected int
|
||||||
|
}{
|
||||||
|
{defaultCredentials, http.StatusOK},
|
||||||
|
{"{\"username\":\"admin\",\"password\":\"wrong\"}", http.StatusForbidden},
|
||||||
|
{"{\"username\":\"wrong\",\"password\":\"admin\"}", http.StatusForbidden},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler(t *testing.T) {
|
||||||
|
fm := newTest(t)
|
||||||
|
defer fm.Clean()
|
||||||
|
|
||||||
|
for _, test := range authHandlerTests {
|
||||||
|
req, err := http.NewRequest("POST", "/api/auth/get", strings.NewReader(test.Data))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
fm.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != test.Expected {
|
||||||
|
t.Errorf("Wrong status code: got %v want %v", w.Code, test.Expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenewHandler(t *testing.T) {
|
||||||
|
fm := newTest(t)
|
||||||
|
defer fm.Clean()
|
||||||
|
|
||||||
|
// First, we have to make an auth request to get the user authenticated,
|
||||||
|
r, err := http.NewRequest("POST", "/api/auth/get", strings.NewReader(defaultCredentials))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
fm.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Couldn't authenticate: got %v", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := w.Body.String()
|
||||||
|
|
||||||
|
// Test renew authorization via Authorization Header.
|
||||||
|
r, err = http.NewRequest("GET", "/api/auth/renew", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
fm.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Can't renew auth via header: got %v", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test renew authorization via cookie field.
|
||||||
|
r, err = http.NewRequest("GET", "/api/auth/renew", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.AddCookie(&http.Cookie{
|
||||||
|
Value: token,
|
||||||
|
Name: "auth",
|
||||||
|
Expires: time.Now().Add(1 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
fm.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Can't renew auth via cookie: got %v", w.Code)
|
||||||
|
}
|
||||||
|
}
|
|
@ -44,7 +44,8 @@ func (f plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.Configs[i].ServeWithErrorHTTP(w, r)
|
f.Configs[i].ServeHTTP(w, r)
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.Next.ServeHTTP(w, r)
|
return f.Next.ServeHTTP(w, r)
|
||||||
|
|
|
@ -168,7 +168,8 @@ func (p plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.Configs[i].ServeWithErrorHTTP(w, r)
|
p.Configs[i].ServeHTTP(w, r)
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.Next.ServeHTTP(w, r)
|
return p.Next.ServeHTTP(w, r)
|
||||||
|
|
66
file.go
66
file.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"crypto/sha512"
|
"crypto/sha512"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -445,3 +446,68 @@ func editorLanguage(mode string) string {
|
||||||
|
|
||||||
return mode
|
return mode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyFile(source string, dest string) (err error) {
|
||||||
|
sourcefile, err := os.Open(source)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer sourcefile.Close()
|
||||||
|
|
||||||
|
destfile, err := os.Create(dest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer destfile.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(destfile, sourcefile)
|
||||||
|
if err == nil {
|
||||||
|
sourceinfo, err := os.Stat(source)
|
||||||
|
if err != nil {
|
||||||
|
err = os.Chmod(dest, sourceinfo.Mode())
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyDir(source string, dest string) (err error) {
|
||||||
|
// get properties of source dir
|
||||||
|
sourceinfo, err := os.Stat(source)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// create dest dir
|
||||||
|
err = os.MkdirAll(dest, sourceinfo.Mode())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
directory, _ := os.Open(source)
|
||||||
|
objects, err := directory.Readdir(-1)
|
||||||
|
|
||||||
|
for _, obj := range objects {
|
||||||
|
sourcefilepointer := source + "/" + obj.Name()
|
||||||
|
destinationfilepointer := dest + "/" + obj.Name()
|
||||||
|
|
||||||
|
if obj.IsDir() {
|
||||||
|
// create sub-directories - recursively
|
||||||
|
err = copyDir(sourcefilepointer, destinationfilepointer)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// perform copy
|
||||||
|
err = copyFile(sourcefilepointer, destinationfilepointer)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -58,8 +58,7 @@ type FileManager struct {
|
||||||
// Command is a command function.
|
// Command is a command function.
|
||||||
type Command func(r *http.Request, m *FileManager, u *User) error
|
type Command func(r *http.Request, m *FileManager, u *User) error
|
||||||
|
|
||||||
// User contains the configuration for each user. It should be created
|
// User contains the configuration for each user.
|
||||||
// using NewUser on a File Manager instance.
|
|
||||||
type User struct {
|
type User struct {
|
||||||
// ID is the required primary key with auto increment0
|
// ID is the required primary key with auto increment0
|
||||||
ID int `storm:"id,increment"`
|
ID int `storm:"id,increment"`
|
||||||
|
@ -349,15 +348,6 @@ func (m *FileManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeWithErrorHTTP returns the code and error of the request.
|
|
||||||
func (m *FileManager) ServeWithErrorHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|
||||||
return serveHTTP(&RequestContext{
|
|
||||||
FM: m,
|
|
||||||
User: nil,
|
|
||||||
FI: nil,
|
|
||||||
}, w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allowed checks if the user has permission to access a directory/file.
|
// Allowed checks if the user has permission to access a directory/file.
|
||||||
func (u User) Allowed(url string) bool {
|
func (u User) Allowed(url string) bool {
|
||||||
var rule *Rule
|
var rule *Rule
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package filemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/webdav"
|
||||||
|
)
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
*FileManager
|
||||||
|
Temp string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t test) Clean() {
|
||||||
|
t.db.Close()
|
||||||
|
os.RemoveAll(t.Temp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTest(t *testing.T) *test {
|
||||||
|
temp, err := ioutil.TempDir("", t.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
scope := filepath.Join(temp, "scope")
|
||||||
|
database := filepath.Join(temp, "database.db")
|
||||||
|
|
||||||
|
err = copyDir("./testdata", scope)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error copying the test data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := DefaultUser
|
||||||
|
user.FileSystem = webdav.Dir(scope)
|
||||||
|
|
||||||
|
fm, err := New(database, user)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating a file manager instance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &test{
|
||||||
|
FileManager: fm,
|
||||||
|
Temp: temp,
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue