mirror of https://github.com/XTLS/Xray-core
				
				
				
			
		
			
				
	
	
		
			433 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			433 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
| package command_test
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/golang/mock/gomock"
 | |
| 	"github.com/google/go-cmp/cmp"
 | |
| 	"github.com/google/go-cmp/cmp/cmpopts"
 | |
| 	"github.com/xtls/xray-core/app/router"
 | |
| 	. "github.com/xtls/xray-core/app/router/command"
 | |
| 	"github.com/xtls/xray-core/app/stats"
 | |
| 	"github.com/xtls/xray-core/common"
 | |
| 	"github.com/xtls/xray-core/common/net"
 | |
| 	"github.com/xtls/xray-core/features/routing"
 | |
| 	"github.com/xtls/xray-core/testing/mocks"
 | |
| 	"google.golang.org/grpc"
 | |
| 	"google.golang.org/grpc/test/bufconn"
 | |
| )
 | |
| 
 | |
| func TestServiceSubscribeRoutingStats(t *testing.T) {
 | |
| 	c := stats.NewChannel(&stats.ChannelConfig{
 | |
| 		SubscriberLimit: 1,
 | |
| 		BufferSize:      0,
 | |
| 		Blocking:        true,
 | |
| 	})
 | |
| 	common.Must(c.Start())
 | |
| 	defer c.Close()
 | |
| 
 | |
| 	lis := bufconn.Listen(1024 * 1024)
 | |
| 	bufDialer := func(context.Context, string) (net.Conn, error) {
 | |
| 		return lis.Dial()
 | |
| 	}
 | |
| 
 | |
| 	testCases := []*RoutingContext{
 | |
| 		{InboundTag: "in", OutboundTag: "out"},
 | |
| 		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
 | |
| 		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
 | |
| 		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
 | |
| 		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
 | |
| 		{Protocol: "bittorrent", OutboundTag: "blocked"},
 | |
| 		{User: "example@example.com", OutboundTag: "out"},
 | |
| 		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
 | |
| 	}
 | |
| 	errCh := make(chan error)
 | |
| 
 | |
| 	// Server goroutine
 | |
| 	go func() {
 | |
| 		server := grpc.NewServer()
 | |
| 		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
 | |
| 		errCh <- server.Serve(lis)
 | |
| 	}()
 | |
| 
 | |
| 	// Publisher goroutine
 | |
| 	go func() {
 | |
| 		publishTestCases := func() error {
 | |
| 			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 | |
| 			defer cancel()
 | |
| 			for { // Wait until there's one subscriber in routing stats channel
 | |
| 				if len(c.Subscribers()) > 0 {
 | |
| 					break
 | |
| 				}
 | |
| 				if ctx.Err() != nil {
 | |
| 					return ctx.Err()
 | |
| 				}
 | |
| 			}
 | |
| 			for _, tc := range testCases {
 | |
| 				c.Publish(context.Background(), AsRoutingRoute(tc))
 | |
| 				time.Sleep(time.Millisecond)
 | |
| 			}
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		if err := publishTestCases(); err != nil {
 | |
| 			errCh <- err
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// Client goroutine
 | |
| 	go func() {
 | |
| 		defer lis.Close()
 | |
| 		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
 | |
| 		if err != nil {
 | |
| 			errCh <- err
 | |
| 			return
 | |
| 		}
 | |
| 		defer conn.Close()
 | |
| 		client := NewRoutingServiceClient(conn)
 | |
| 
 | |
| 		// Test retrieving all fields
 | |
| 		testRetrievingAllFields := func() error {
 | |
| 			streamCtx, streamClose := context.WithCancel(context.Background())
 | |
| 
 | |
| 			// Test the unsubscription of stream works well
 | |
| 			defer func() {
 | |
| 				streamClose()
 | |
| 				timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
 | |
| 				defer timeout()
 | |
| 				for { // Wait until there's no subscriber in routing stats channel
 | |
| 					if len(c.Subscribers()) == 0 {
 | |
| 						break
 | |
| 					}
 | |
| 					if timeOutCtx.Err() != nil {
 | |
| 						t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
 | |
| 					}
 | |
| 				}
 | |
| 			}()
 | |
| 
 | |
| 			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 
 | |
| 			for _, tc := range testCases {
 | |
| 				msg, err := stream.Recv()
 | |
| 				if err != nil {
 | |
| 					return err
 | |
| 				}
 | |
| 				if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
 | |
| 					t.Error(r)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			// Test that double subscription will fail
 | |
| 			errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
 | |
| 				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
 | |
| 			})
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			if _, err := errStream.Recv(); err == nil {
 | |
| 				t.Error("unexpected successful subscription")
 | |
| 			}
 | |
| 
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		if err := testRetrievingAllFields(); err != nil {
 | |
| 			errCh <- err
 | |
| 		}
 | |
| 		errCh <- nil // Client passed all tests successfully
 | |
| 	}()
 | |
| 
 | |
| 	// Wait for goroutines to complete
 | |
| 	select {
 | |
| 	case <-time.After(2 * time.Second):
 | |
| 		t.Fatal("Test timeout after 2s")
 | |
| 	case err := <-errCh:
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServiceSubscribeSubsetOfFields(t *testing.T) {
 | |
| 	c := stats.NewChannel(&stats.ChannelConfig{
 | |
| 		SubscriberLimit: 1,
 | |
| 		BufferSize:      0,
 | |
| 		Blocking:        true,
 | |
| 	})
 | |
| 	common.Must(c.Start())
 | |
| 	defer c.Close()
 | |
| 
 | |
| 	lis := bufconn.Listen(1024 * 1024)
 | |
| 	bufDialer := func(context.Context, string) (net.Conn, error) {
 | |
| 		return lis.Dial()
 | |
| 	}
 | |
| 
 | |
| 	testCases := []*RoutingContext{
 | |
| 		{InboundTag: "in", OutboundTag: "out"},
 | |
| 		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
 | |
| 		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
 | |
| 		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
 | |
| 		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
 | |
| 		{Protocol: "bittorrent", OutboundTag: "blocked"},
 | |
| 		{User: "example@example.com", OutboundTag: "out"},
 | |
| 		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
 | |
| 	}
 | |
| 	errCh := make(chan error)
 | |
| 
 | |
| 	// Server goroutine
 | |
| 	go func() {
 | |
| 		server := grpc.NewServer()
 | |
| 		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
 | |
| 		errCh <- server.Serve(lis)
 | |
| 	}()
 | |
| 
 | |
| 	// Publisher goroutine
 | |
| 	go func() {
 | |
| 		publishTestCases := func() error {
 | |
| 			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 | |
| 			defer cancel()
 | |
| 			for { // Wait until there's one subscriber in routing stats channel
 | |
| 				if len(c.Subscribers()) > 0 {
 | |
| 					break
 | |
| 				}
 | |
| 				if ctx.Err() != nil {
 | |
| 					return ctx.Err()
 | |
| 				}
 | |
| 			}
 | |
| 			for _, tc := range testCases {
 | |
| 				c.Publish(context.Background(), AsRoutingRoute(tc))
 | |
| 				time.Sleep(time.Millisecond)
 | |
| 			}
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		if err := publishTestCases(); err != nil {
 | |
| 			errCh <- err
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// Client goroutine
 | |
| 	go func() {
 | |
| 		defer lis.Close()
 | |
| 		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
 | |
| 		if err != nil {
 | |
| 			errCh <- err
 | |
| 			return
 | |
| 		}
 | |
| 		defer conn.Close()
 | |
| 		client := NewRoutingServiceClient(conn)
 | |
| 
 | |
| 		// Test retrieving only a subset of fields
 | |
| 		testRetrievingSubsetOfFields := func() error {
 | |
| 			streamCtx, streamClose := context.WithCancel(context.Background())
 | |
| 			defer streamClose()
 | |
| 			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
 | |
| 				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
 | |
| 			})
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 
 | |
| 			for _, tc := range testCases {
 | |
| 				msg, err := stream.Recv()
 | |
| 				if err != nil {
 | |
| 					return err
 | |
| 				}
 | |
| 				stat := &RoutingContext{ // Only a subset of stats is retrieved
 | |
| 					SourceIPs:         tc.SourceIPs,
 | |
| 					TargetIPs:         tc.TargetIPs,
 | |
| 					SourcePort:        tc.SourcePort,
 | |
| 					TargetPort:        tc.TargetPort,
 | |
| 					TargetDomain:      tc.TargetDomain,
 | |
| 					OutboundGroupTags: tc.OutboundGroupTags,
 | |
| 					OutboundTag:       tc.OutboundTag,
 | |
| 				}
 | |
| 				if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
 | |
| 					t.Error(r)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			return nil
 | |
| 		}
 | |
| 		if err := testRetrievingSubsetOfFields(); err != nil {
 | |
| 			errCh <- err
 | |
| 		}
 | |
| 		errCh <- nil // Client passed all tests successfully
 | |
| 	}()
 | |
| 
 | |
| 	// Wait for goroutines to complete
 | |
| 	select {
 | |
| 	case <-time.After(2 * time.Second):
 | |
| 		t.Fatal("Test timeout after 2s")
 | |
| 	case err := <-errCh:
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestSerivceTestRoute(t *testing.T) {
 | |
| 	c := stats.NewChannel(&stats.ChannelConfig{
 | |
| 		SubscriberLimit: 1,
 | |
| 		BufferSize:      16,
 | |
| 		Blocking:        true,
 | |
| 	})
 | |
| 	common.Must(c.Start())
 | |
| 	defer c.Close()
 | |
| 
 | |
| 	r := new(router.Router)
 | |
| 	mockCtl := gomock.NewController(t)
 | |
| 	defer mockCtl.Finish()
 | |
| 	common.Must(r.Init(context.TODO(), &router.Config{
 | |
| 		Rule: []*router.RoutingRule{
 | |
| 			{
 | |
| 				InboundTag: []string{"in"},
 | |
| 				TargetTag:  &router.RoutingRule_Tag{Tag: "out"},
 | |
| 			},
 | |
| 			{
 | |
| 				Protocol:  []string{"bittorrent"},
 | |
| 				TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
 | |
| 			},
 | |
| 			{
 | |
| 				PortList:  &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
 | |
| 				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
 | |
| 			},
 | |
| 			{
 | |
| 				SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
 | |
| 				TargetTag:      &router.RoutingRule_Tag{Tag: "out"},
 | |
| 			},
 | |
| 			{
 | |
| 				Domain:    []*router.Domain{{Type: router.Domain_Domain, Value: "com"}},
 | |
| 				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
 | |
| 			},
 | |
| 			{
 | |
| 				SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
 | |
| 				TargetTag:   &router.RoutingRule_Tag{Tag: "out"},
 | |
| 			},
 | |
| 			{
 | |
| 				UserEmail: []string{"example@example.com"},
 | |
| 				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
 | |
| 			},
 | |
| 			{
 | |
| 				Networks:  []net.Network{net.Network_UDP, net.Network_TCP},
 | |
| 				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
 | |
| 			},
 | |
| 		},
 | |
| 	}, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl)))
 | |
| 
 | |
| 	lis := bufconn.Listen(1024 * 1024)
 | |
| 	bufDialer := func(context.Context, string) (net.Conn, error) {
 | |
| 		return lis.Dial()
 | |
| 	}
 | |
| 
 | |
| 	errCh := make(chan error)
 | |
| 
 | |
| 	// Server goroutine
 | |
| 	go func() {
 | |
| 		server := grpc.NewServer()
 | |
| 		RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
 | |
| 		errCh <- server.Serve(lis)
 | |
| 	}()
 | |
| 
 | |
| 	// Client goroutine
 | |
| 	go func() {
 | |
| 		defer lis.Close()
 | |
| 		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
 | |
| 		if err != nil {
 | |
| 			errCh <- err
 | |
| 		}
 | |
| 		defer conn.Close()
 | |
| 		client := NewRoutingServiceClient(conn)
 | |
| 
 | |
| 		testCases := []*RoutingContext{
 | |
| 			{InboundTag: "in", OutboundTag: "out"},
 | |
| 			{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
 | |
| 			{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
 | |
| 			{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
 | |
| 			{Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
 | |
| 			{User: "example@example.com", OutboundTag: "out"},
 | |
| 			{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
 | |
| 		}
 | |
| 
 | |
| 		// Test simple TestRoute
 | |
| 		testSimple := func() error {
 | |
| 			for _, tc := range testCases {
 | |
| 				route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
 | |
| 				if err != nil {
 | |
| 					return err
 | |
| 				}
 | |
| 				if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
 | |
| 					t.Error(r)
 | |
| 				}
 | |
| 			}
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		// Test TestRoute with special options
 | |
| 		testOptions := func() error {
 | |
| 			sub, err := c.Subscribe()
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			for _, tc := range testCases {
 | |
| 				route, err := client.TestRoute(context.Background(), &TestRouteRequest{
 | |
| 					RoutingContext: tc,
 | |
| 					FieldSelectors: []string{"ip", "port", "domain", "outbound"},
 | |
| 					PublishResult:  true,
 | |
| 				})
 | |
| 				if err != nil {
 | |
| 					return err
 | |
| 				}
 | |
| 				stat := &RoutingContext{ // Only a subset of stats is retrieved
 | |
| 					SourceIPs:         tc.SourceIPs,
 | |
| 					TargetIPs:         tc.TargetIPs,
 | |
| 					SourcePort:        tc.SourcePort,
 | |
| 					TargetPort:        tc.TargetPort,
 | |
| 					TargetDomain:      tc.TargetDomain,
 | |
| 					OutboundGroupTags: tc.OutboundGroupTags,
 | |
| 					OutboundTag:       tc.OutboundTag,
 | |
| 				}
 | |
| 				if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
 | |
| 					t.Error(r)
 | |
| 				}
 | |
| 				select { // Check that routing result has been published to statistics channel
 | |
| 				case msg, received := <-sub:
 | |
| 					if route, ok := msg.(routing.Route); received && ok {
 | |
| 						if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
 | |
| 							t.Error(r)
 | |
| 						}
 | |
| 					} else {
 | |
| 						t.Error("unexpected failure in receiving published routing result for testcase", tc)
 | |
| 					}
 | |
| 				case <-time.After(100 * time.Millisecond):
 | |
| 					t.Error("unexpected failure in receiving published routing result", tc)
 | |
| 				}
 | |
| 			}
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		if err := testSimple(); err != nil {
 | |
| 			errCh <- err
 | |
| 		}
 | |
| 		if err := testOptions(); err != nil {
 | |
| 			errCh <- err
 | |
| 		}
 | |
| 		errCh <- nil // Client passed all tests successfully
 | |
| 	}()
 | |
| 
 | |
| 	// Wait for goroutines to complete
 | |
| 	select {
 | |
| 	case <-time.After(2 * time.Second):
 | |
| 		t.Fatal("Test timeout after 2s")
 | |
| 	case err := <-errCh:
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}
 | |
| }
 |