Add some auth tests

pull/157/head
Henrique Dias 2017-07-25 11:57:27 +01:00
parent eb01267643
commit e6e1984c47
No known key found for this signature in database
GPG Key ID: 936F5EB68D786730
7 changed files with 214 additions and 15 deletions

View File

@ -94,10 +94,10 @@ type extractor []string
func (e extractor) ExtractToken(r *http.Request) (string, error) {
token, _ := request.AuthorizationHeaderExtractor.ExtractToken(r)
// Checks if the token isn't empty and if it contains three dots.
// Checks if the token isn't empty and if it contains two dots.
// The former prevents incompatibility with URLs that previously
// used basic auth.
if token != "" && strings.Count(token, ".") == 3 {
if token != "" && strings.Count(token, ".") == 2 {
return token, nil
}

92
auth_test.go Normal file
View File

@ -0,0 +1,92 @@
package filemanager
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
var defaultCredentials = "{\"username\":\"admin\",\"password\":\"admin\"}"
var authHandlerTests = []struct {
Data string
Expected int
}{
{defaultCredentials, http.StatusOK},
{"{\"username\":\"admin\",\"password\":\"wrong\"}", http.StatusForbidden},
{"{\"username\":\"wrong\",\"password\":\"admin\"}", http.StatusForbidden},
}
func TestAuthHandler(t *testing.T) {
fm := newTest(t)
defer fm.Clean()
for _, test := range authHandlerTests {
req, err := http.NewRequest("POST", "/api/auth/get", strings.NewReader(test.Data))
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
fm.ServeHTTP(w, req)
if w.Code != test.Expected {
t.Errorf("Wrong status code: got %v want %v", w.Code, test.Expected)
}
}
}
func TestRenewHandler(t *testing.T) {
fm := newTest(t)
defer fm.Clean()
// First, we have to make an auth request to get the user authenticated,
r, err := http.NewRequest("POST", "/api/auth/get", strings.NewReader(defaultCredentials))
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
fm.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("Couldn't authenticate: got %v", w.Code)
}
token := w.Body.String()
// Test renew authorization via Authorization Header.
r, err = http.NewRequest("GET", "/api/auth/renew", nil)
if err != nil {
t.Fatal(err)
}
r.Header.Set("Authorization", "Bearer "+token)
w = httptest.NewRecorder()
fm.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("Can't renew auth via header: got %v", w.Code)
}
// Test renew authorization via cookie field.
r, err = http.NewRequest("GET", "/api/auth/renew", nil)
if err != nil {
t.Fatal(err)
}
r.AddCookie(&http.Cookie{
Value: token,
Name: "auth",
Expires: time.Now().Add(1 * time.Hour),
})
w = httptest.NewRecorder()
fm.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("Can't renew auth via cookie: got %v", w.Code)
}
}

View File

@ -44,7 +44,8 @@ func (f plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
continue
}
return f.Configs[i].ServeWithErrorHTTP(w, r)
f.Configs[i].ServeHTTP(w, r)
return 0, nil
}
return f.Next.ServeHTTP(w, r)

View File

@ -168,7 +168,8 @@ func (p plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
continue
}
return p.Configs[i].ServeWithErrorHTTP(w, r)
p.Configs[i].ServeHTTP(w, r)
return 0, nil
}
return p.Next.ServeHTTP(w, r)

66
file.go
View File

@ -9,6 +9,7 @@ import (
"crypto/sha512"
"encoding/hex"
"errors"
"fmt"
"hash"
"io"
"io/ioutil"
@ -445,3 +446,68 @@ func editorLanguage(mode string) string {
return mode
}
func copyFile(source string, dest string) (err error) {
sourcefile, err := os.Open(source)
if err != nil {
return err
}
defer sourcefile.Close()
destfile, err := os.Create(dest)
if err != nil {
return err
}
defer destfile.Close()
_, err = io.Copy(destfile, sourcefile)
if err == nil {
sourceinfo, err := os.Stat(source)
if err != nil {
err = os.Chmod(dest, sourceinfo.Mode())
}
}
return
}
func copyDir(source string, dest string) (err error) {
// get properties of source dir
sourceinfo, err := os.Stat(source)
if err != nil {
return err
}
// create dest dir
err = os.MkdirAll(dest, sourceinfo.Mode())
if err != nil {
return err
}
directory, _ := os.Open(source)
objects, err := directory.Readdir(-1)
for _, obj := range objects {
sourcefilepointer := source + "/" + obj.Name()
destinationfilepointer := dest + "/" + obj.Name()
if obj.IsDir() {
// create sub-directories - recursively
err = copyDir(sourcefilepointer, destinationfilepointer)
if err != nil {
fmt.Println(err)
}
} else {
// perform copy
err = copyFile(sourcefilepointer, destinationfilepointer)
if err != nil {
fmt.Println(err)
}
}
}
return
}

View File

@ -58,8 +58,7 @@ type FileManager struct {
// Command is a command function.
type Command func(r *http.Request, m *FileManager, u *User) error
// User contains the configuration for each user. It should be created
// using NewUser on a File Manager instance.
// User contains the configuration for each user.
type User struct {
// ID is the required primary key with auto increment0
ID int `storm:"id,increment"`
@ -349,15 +348,6 @@ func (m *FileManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
// ServeWithErrorHTTP returns the code and error of the request.
func (m *FileManager) ServeWithErrorHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
return serveHTTP(&RequestContext{
FM: m,
User: nil,
FI: nil,
}, w, r)
}
// Allowed checks if the user has permission to access a directory/file.
func (u User) Allowed(url string) bool {
var rule *Rule

49
filemanager_test.go Normal file
View File

@ -0,0 +1,49 @@
package filemanager
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"golang.org/x/net/webdav"
)
type test struct {
*FileManager
Temp string
}
func (t test) Clean() {
t.db.Close()
os.RemoveAll(t.Temp)
}
func newTest(t *testing.T) *test {
temp, err := ioutil.TempDir("", t.Name())
if err != nil {
t.Fatalf("Error creating temporary directory: %v", err)
}
scope := filepath.Join(temp, "scope")
database := filepath.Join(temp, "database.db")
err = copyDir("./testdata", scope)
if err != nil {
t.Fatalf("Error copying the test data: %v", err)
}
user := DefaultUser
user.FileSystem = webdav.Dir(scope)
fm, err := New(database, user)
if err != nil {
t.Fatalf("Error creating a file manager instance: %v", err)
}
return &test{
FileManager: fm,
Temp: temp,
}
}