diff --git a/infra/conf/dns_test.go b/infra/conf/dns_test.go index d28c2041..5d423884 100644 --- a/infra/conf/dns_test.go +++ b/infra/conf/dns_test.go @@ -2,35 +2,15 @@ package conf_test import ( "encoding/json" - "os" - "path/filepath" "testing" "github.com/xtls/xray-core/app/dns" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/platform/filesystem" . "github.com/xtls/xray-core/infra/conf" "google.golang.org/protobuf/proto" ) -func init() { - wd, err := os.Getwd() - common.Must(err) - - if _, err := os.Stat(platform.GetAssetLocation("geoip.dat")); err != nil && os.IsNotExist(err) { - common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(wd, "..", "..", "resources", "geoip.dat"))) - } - - os.Setenv("xray.location.asset", wd) -} - func TestDNSConfigParsing(t *testing.T) { - defer func() { - os.Unsetenv("xray.location.asset") - }() - parserCreator := func() func(string) (proto.Message, error) { return func(s string) (proto.Message, error) { config := new(DNSConfig) diff --git a/infra/conf/router_test.go b/infra/conf/router_test.go index 340b871c..35c68c68 100644 --- a/infra/conf/router_test.go +++ b/infra/conf/router_test.go @@ -2,6 +2,7 @@ package conf_test import ( "encoding/json" + "fmt" "os" "path/filepath" "testing" @@ -18,21 +19,44 @@ import ( "google.golang.org/protobuf/proto" ) -func init() { - wd, err := os.Getwd() - common.Must(err) - - if _, err := os.Stat(platform.GetAssetLocation("geoip.dat")); err != nil && os.IsNotExist(err) { - common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(wd, "..", "..", "resources", "geoip.dat"))) +func getAssetPath(file string) (string, error) { + path := platform.GetAssetLocation(file) + _, err := os.Stat(path) + if os.IsNotExist(err) { + path := filepath.Join("..", "..", "resources", file) + _, err := os.Stat(path) + if os.IsNotExist(err) { + return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file) + } + if err != nil { + return "", fmt.Errorf("can't stat %s: %v", path, err) + } + return path, nil + } + if err != nil { + return "", fmt.Errorf("can't stat %s: %v", path, err) } - os.Setenv("xray.location.asset", wd) + return path, nil } func TestToCidrList(t *testing.T) { - t.Log(os.Getenv("xray.location.asset")) + tempDir, err := os.MkdirTemp("", "test-") + if err != nil { + t.Fatalf("can't create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + geoipPath, err := getAssetPath("geoip.dat") + if err != nil { + t.Fatal(err) + } + + common.Must(filesystem.CopyFile(filepath.Join(tempDir, "geoip.dat"), geoipPath)) + common.Must(filesystem.CopyFile(filepath.Join(tempDir, "geoiptestrouter.dat"), geoipPath)) - common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoiptestrouter.dat"), "geoip.dat")) + os.Setenv("xray.location.asset", tempDir) + defer os.Unsetenv("xray.location.asset") ips := StringList([]string{ "geoip:us", @@ -44,7 +68,7 @@ func TestToCidrList(t *testing.T) { "ext-ip:geoiptestrouter.dat:!ca", }) - _, err := ToCidrList(ips) + _, err = ToCidrList(ips) if err != nil { t.Fatalf("Failed to parse geoip list, got %s", err) }