diff --git a/.github/linters/.golangci.yml b/.github/linters/.golangci.yml index 19167f21..861db321 100644 --- a/.github/linters/.golangci.yml +++ b/.github/linters/.golangci.yml @@ -1,9 +1,21 @@ run: + timeout: 5m skip-dirs: - external skip-files: - generated.* +issues: + new: true + exclude-rules: + - path: _test\.go + linters: + - gocyclo + - errcheck + - dupl + - gosec + - goconst + linters: enable: - bodyclose @@ -13,22 +25,18 @@ linters: - dupl - errcheck - exhaustive - - funlen - - gochecknoinits - goconst - gocritic - gocyclo - gofmt - goimports - golint - - gomnd - goprintffuncname - gosec - gosimple - govet - ineffassign - interfacer - - lll - misspell - nakedret - noctx @@ -44,6 +52,3 @@ linters: - unused - varcheck - whitespace - -issues: - new: true \ No newline at end of file diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index cdf33594..c21f1297 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -1,4 +1,4 @@ -name: "CodeQL" +name: CodeQL on: push: @@ -6,8 +6,8 @@ on: paths: - "**/*.go" pull_request: - # The branches below must be a subset of the branches above branches: [master] + types: [opened, synchronize, reopened] paths: - "**/*.go" schedule: @@ -15,9 +15,7 @@ on: jobs: analyze: - name: Analyze runs-on: ubuntu-latest - strategy: fail-fast: false matrix: diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index bdf74bb1..310d0378 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -28,3 +28,4 @@ jobs: with: version: v1.31 args: --config=.github/linters/.golangci.yml + only-new-issues: true diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index c5aaf84f..d74f243e 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -266,7 +266,8 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. } if d.router != nil && !skipRoutePick { - if tag, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil { + if route, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil { + tag := route.GetOutboundTag() if h := d.ohm.GetHandler(tag); h != nil { newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) handler = h diff --git a/app/dns/server.go b/app/dns/server.go index 2810d28f..aaf8c240 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -99,13 +99,22 @@ func New(ctx context.Context, config *Config) (*Server, error) { address := endpoint.Address.AsAddress() if address.Family().IsDomain() && address.Domain() == "localhost" { server.clients = append(server.clients, NewLocalNameServer()) - if len(ns.PrioritizedDomain) == 0 { // Priotize local domain with .local domain or without any dot to local DNS - ns.PrioritizedDomain = []*NameServer_PriorityDomain{ - {Type: DomainMatchingType_Regex, Domain: "^[^.]*$"}, // This will only match domain without any dot - {Type: DomainMatchingType_Subdomain, Domain: "local"}, - {Type: DomainMatchingType_Subdomain, Domain: "localdomain"}, - } + // Priotize local domains with specific TLDs or without any dot to local DNS + // References: + // https://www.iana.org/assignments/special-use-domain-names/special-use-domain-names.xhtml + // https://unix.stackexchange.com/questions/92441/whats-the-difference-between-local-home-and-lan + localTLDsAndDotlessDomains := []*NameServer_PriorityDomain{ + {Type: DomainMatchingType_Regex, Domain: "^[^.]+$"}, // This will only match domains without any dot + {Type: DomainMatchingType_Subdomain, Domain: "local"}, + {Type: DomainMatchingType_Subdomain, Domain: "localdomain"}, + {Type: DomainMatchingType_Subdomain, Domain: "localhost"}, + {Type: DomainMatchingType_Subdomain, Domain: "lan"}, + {Type: DomainMatchingType_Subdomain, Domain: "home.arpa"}, + {Type: DomainMatchingType_Subdomain, Domain: "example"}, + {Type: DomainMatchingType_Subdomain, Domain: "invalid"}, + {Type: DomainMatchingType_Subdomain, Domain: "test"}, } + ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...) } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://") { // URI schemed string treated as domain // DOH Local mode diff --git a/app/router/command/command.go b/app/router/command/command.go new file mode 100644 index 00000000..6add0441 --- /dev/null +++ b/app/router/command/command.go @@ -0,0 +1,90 @@ +// +build !confonly + +package command + +//go:generate errorgen + +import ( + "context" + + "google.golang.org/grpc" + + "v2ray.com/core" + "v2ray.com/core/common" + "v2ray.com/core/features/routing" + "v2ray.com/core/features/stats" +) + +// routingServer is an implementation of RoutingService. +type routingServer struct { + router routing.Router + routingStats stats.Channel +} + +// NewRoutingServer creates a statistics service with statistics manager. +func NewRoutingServer(router routing.Router, routingStats stats.Channel) RoutingServiceServer { + return &routingServer{ + router: router, + routingStats: routingStats, + } +} + +func (s *routingServer) TestRoute(ctx context.Context, request *TestRouteRequest) (*RoutingContext, error) { + if request.RoutingContext == nil { + return nil, newError("Invalid routing request.") + } + route, err := s.router.PickRoute(AsRoutingContext(request.RoutingContext)) + if err != nil { + return nil, err + } + if request.PublishResult && s.routingStats != nil { + s.routingStats.Publish(route) + } + return AsProtobufMessage(request.FieldSelectors)(route), nil +} + +func (s *routingServer) SubscribeRoutingStats(request *SubscribeRoutingStatsRequest, stream RoutingService_SubscribeRoutingStatsServer) error { + if s.routingStats == nil { + return newError("Routing statistics not enabled.") + } + genMessage := AsProtobufMessage(request.FieldSelectors) + subscriber, err := stats.SubscribeRunnableChannel(s.routingStats) + if err != nil { + return err + } + defer stats.UnsubscribeClosableChannel(s.routingStats, subscriber) // nolint: errcheck + for { + select { + case value, received := <-subscriber: + route, ok := value.(routing.Route) + if !(received && ok) { + return newError("Receiving upstream statistics failed.") + } + err := stream.Send(genMessage(route)) + if err != nil { + return err + } + case <-stream.Context().Done(): + return stream.Context().Err() + } + } +} + +func (s *routingServer) mustEmbedUnimplementedRoutingServiceServer() {} + +type service struct { + v *core.Instance +} + +func (s *service) Register(server *grpc.Server) { + common.Must(s.v.RequireFeatures(func(router routing.Router, stats stats.Manager) { + RegisterRoutingServiceServer(server, NewRoutingServer(router, nil)) + })) +} + +func init() { + common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, cfg interface{}) (interface{}, error) { + s := core.MustFromContext(ctx) + return &service{v: s}, nil + })) +} diff --git a/app/router/command/command.pb.go b/app/router/command/command.pb.go new file mode 100644 index 00000000..2c3691b2 --- /dev/null +++ b/app/router/command/command.pb.go @@ -0,0 +1,525 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: app/router/command/command.proto + +package command + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + net "v2ray.com/core/common/net" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +// RoutingContext is the context with information relative to routing process. +// It conforms to the structure of v2ray.core.features.routing.Context and v2ray.core.features.routing.Route. +type RoutingContext struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + InboundTag string `protobuf:"bytes,1,opt,name=InboundTag,proto3" json:"InboundTag,omitempty"` + Network net.Network `protobuf:"varint,2,opt,name=Network,proto3,enum=v2ray.core.common.net.Network" json:"Network,omitempty"` + SourceIPs [][]byte `protobuf:"bytes,3,rep,name=SourceIPs,proto3" json:"SourceIPs,omitempty"` + TargetIPs [][]byte `protobuf:"bytes,4,rep,name=TargetIPs,proto3" json:"TargetIPs,omitempty"` + SourcePort uint32 `protobuf:"varint,5,opt,name=SourcePort,proto3" json:"SourcePort,omitempty"` + TargetPort uint32 `protobuf:"varint,6,opt,name=TargetPort,proto3" json:"TargetPort,omitempty"` + TargetDomain string `protobuf:"bytes,7,opt,name=TargetDomain,proto3" json:"TargetDomain,omitempty"` + Protocol string `protobuf:"bytes,8,opt,name=Protocol,proto3" json:"Protocol,omitempty"` + User string `protobuf:"bytes,9,opt,name=User,proto3" json:"User,omitempty"` + Attributes map[string]string `protobuf:"bytes,10,rep,name=Attributes,proto3" json:"Attributes,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + OutboundGroupTags []string `protobuf:"bytes,11,rep,name=OutboundGroupTags,proto3" json:"OutboundGroupTags,omitempty"` + OutboundTag string `protobuf:"bytes,12,opt,name=OutboundTag,proto3" json:"OutboundTag,omitempty"` +} + +func (x *RoutingContext) Reset() { + *x = RoutingContext{} + if protoimpl.UnsafeEnabled { + mi := &file_app_router_command_command_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RoutingContext) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RoutingContext) ProtoMessage() {} + +func (x *RoutingContext) ProtoReflect() protoreflect.Message { + mi := &file_app_router_command_command_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RoutingContext.ProtoReflect.Descriptor instead. +func (*RoutingContext) Descriptor() ([]byte, []int) { + return file_app_router_command_command_proto_rawDescGZIP(), []int{0} +} + +func (x *RoutingContext) GetInboundTag() string { + if x != nil { + return x.InboundTag + } + return "" +} + +func (x *RoutingContext) GetNetwork() net.Network { + if x != nil { + return x.Network + } + return net.Network_Unknown +} + +func (x *RoutingContext) GetSourceIPs() [][]byte { + if x != nil { + return x.SourceIPs + } + return nil +} + +func (x *RoutingContext) GetTargetIPs() [][]byte { + if x != nil { + return x.TargetIPs + } + return nil +} + +func (x *RoutingContext) GetSourcePort() uint32 { + if x != nil { + return x.SourcePort + } + return 0 +} + +func (x *RoutingContext) GetTargetPort() uint32 { + if x != nil { + return x.TargetPort + } + return 0 +} + +func (x *RoutingContext) GetTargetDomain() string { + if x != nil { + return x.TargetDomain + } + return "" +} + +func (x *RoutingContext) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + +func (x *RoutingContext) GetUser() string { + if x != nil { + return x.User + } + return "" +} + +func (x *RoutingContext) GetAttributes() map[string]string { + if x != nil { + return x.Attributes + } + return nil +} + +func (x *RoutingContext) GetOutboundGroupTags() []string { + if x != nil { + return x.OutboundGroupTags + } + return nil +} + +func (x *RoutingContext) GetOutboundTag() string { + if x != nil { + return x.OutboundTag + } + return "" +} + +// SubscribeRoutingStatsRequest subscribes to routing statistics channel if opened by v2ray-core. +// * FieldSelectors selects a subset of fields in routing statistics to return. Valid selectors: +// - inbound: Selects connection's inbound tag. +// - network: Selects connection's network. +// - ip: Equivalent as "ip_source" and "ip_target", selects both source and target IP. +// - port: Equivalent as "port_source" and "port_target", selects both source and target port. +// - domain: Selects target domain. +// - protocol: Select connection's protocol. +// - user: Select connection's inbound user email. +// - attributes: Select connection's additional attributes. +// - outbound: Equivalent as "outbound" and "outbound_group", select both outbound tag and outbound group tags. +// * If FieldSelectors is left empty, all fields will be returned. +type SubscribeRoutingStatsRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + FieldSelectors []string `protobuf:"bytes,1,rep,name=FieldSelectors,proto3" json:"FieldSelectors,omitempty"` +} + +func (x *SubscribeRoutingStatsRequest) Reset() { + *x = SubscribeRoutingStatsRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_app_router_command_command_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SubscribeRoutingStatsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SubscribeRoutingStatsRequest) ProtoMessage() {} + +func (x *SubscribeRoutingStatsRequest) ProtoReflect() protoreflect.Message { + mi := &file_app_router_command_command_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SubscribeRoutingStatsRequest.ProtoReflect.Descriptor instead. +func (*SubscribeRoutingStatsRequest) Descriptor() ([]byte, []int) { + return file_app_router_command_command_proto_rawDescGZIP(), []int{1} +} + +func (x *SubscribeRoutingStatsRequest) GetFieldSelectors() []string { + if x != nil { + return x.FieldSelectors + } + return nil +} + +// TestRouteRequest manually tests a routing result according to the routing context message. +// * RoutingContext is the routing message without outbound information. +// * FieldSelectors selects the fields to return in the routing result. All fields are returned if left empty. +// * PublishResult broadcasts the routing result to routing statistics channel if set true. +type TestRouteRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + RoutingContext *RoutingContext `protobuf:"bytes,1,opt,name=RoutingContext,proto3" json:"RoutingContext,omitempty"` + FieldSelectors []string `protobuf:"bytes,2,rep,name=FieldSelectors,proto3" json:"FieldSelectors,omitempty"` + PublishResult bool `protobuf:"varint,3,opt,name=PublishResult,proto3" json:"PublishResult,omitempty"` +} + +func (x *TestRouteRequest) Reset() { + *x = TestRouteRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_app_router_command_command_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TestRouteRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TestRouteRequest) ProtoMessage() {} + +func (x *TestRouteRequest) ProtoReflect() protoreflect.Message { + mi := &file_app_router_command_command_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TestRouteRequest.ProtoReflect.Descriptor instead. +func (*TestRouteRequest) Descriptor() ([]byte, []int) { + return file_app_router_command_command_proto_rawDescGZIP(), []int{2} +} + +func (x *TestRouteRequest) GetRoutingContext() *RoutingContext { + if x != nil { + return x.RoutingContext + } + return nil +} + +func (x *TestRouteRequest) GetFieldSelectors() []string { + if x != nil { + return x.FieldSelectors + } + return nil +} + +func (x *TestRouteRequest) GetPublishResult() bool { + if x != nil { + return x.PublishResult + } + return false +} + +type Config struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *Config) Reset() { + *x = Config{} + if protoimpl.UnsafeEnabled { + mi := &file_app_router_command_command_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_app_router_command_command_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_app_router_command_command_proto_rawDescGZIP(), []int{3} +} + +var File_app_router_command_command_proto protoreflect.FileDescriptor + +var file_app_router_command_command_proto_rawDesc = []byte{ + 0x0a, 0x20, 0x61, 0x70, 0x70, 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x6d, + 0x6d, 0x61, 0x6e, 0x64, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x12, 0x1d, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, + 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, + 0x64, 0x1a, 0x18, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x6e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa8, 0x04, 0x0a, 0x0e, + 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x1e, + 0x0a, 0x0a, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x54, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0a, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x54, 0x61, 0x67, 0x12, 0x38, + 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x1e, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x63, 0x6f, 0x6d, + 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x49, 0x50, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x6f, 0x75, + 0x72, 0x63, 0x65, 0x49, 0x50, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, + 0x49, 0x50, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x09, 0x54, 0x61, 0x72, 0x67, 0x65, + 0x74, 0x49, 0x50, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, + 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x50, 0x6f, + 0x72, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x54, 0x61, 0x72, 0x67, + 0x65, 0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x55, 0x73, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x5d, 0x0a, 0x0a, 0x41, 0x74, 0x74, 0x72, + 0x69, 0x62, 0x75, 0x74, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x3d, 0x2e, 0x76, + 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x2e, 0x52, 0x6f, 0x75, + 0x74, 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x72, + 0x69, 0x62, 0x75, 0x74, 0x65, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0a, 0x41, 0x74, 0x74, + 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x11, 0x4f, 0x75, 0x74, 0x62, 0x6f, + 0x75, 0x6e, 0x64, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x54, 0x61, 0x67, 0x73, 0x18, 0x0b, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x11, 0x4f, 0x75, 0x74, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x54, 0x61, 0x67, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x4f, 0x75, 0x74, 0x62, 0x6f, 0x75, 0x6e, + 0x64, 0x54, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x4f, 0x75, 0x74, 0x62, + 0x6f, 0x75, 0x6e, 0x64, 0x54, 0x61, 0x67, 0x1a, 0x3d, 0x0a, 0x0f, 0x41, 0x74, 0x74, 0x72, 0x69, + 0x62, 0x75, 0x74, 0x65, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x46, 0x0a, 0x1c, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, + 0x69, 0x62, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x53, 0x74, 0x61, 0x74, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x53, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, + 0x46, 0x69, 0x65, 0x6c, 0x64, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x73, 0x22, 0xb7, + 0x01, 0x0a, 0x10, 0x54, 0x65, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x55, 0x0a, 0x0e, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x43, 0x6f, + 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x76, 0x32, + 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, + 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x2e, 0x52, 0x6f, 0x75, 0x74, + 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x52, 0x0e, 0x52, 0x6f, 0x75, 0x74, + 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x46, 0x69, + 0x65, 0x6c, 0x64, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x6f, + 0x72, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x52, 0x65, 0x73, + 0x75, 0x6c, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x75, 0x62, 0x6c, 0x69, + 0x73, 0x68, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x08, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x32, 0x89, 0x02, 0x0a, 0x0e, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x87, 0x01, 0x0a, 0x15, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, + 0x69, 0x62, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x53, 0x74, 0x61, 0x74, 0x73, 0x12, + 0x3b, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, + 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x2e, + 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, + 0x53, 0x74, 0x61, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, 0x76, + 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x2e, 0x52, 0x6f, 0x75, + 0x74, 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, + 0x6d, 0x0a, 0x09, 0x54, 0x65, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x2f, 0x2e, 0x76, + 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x2e, 0x54, 0x65, 0x73, + 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, + 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x2e, 0x52, 0x6f, + 0x75, 0x74, 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x22, 0x00, 0x42, 0x68, + 0x0a, 0x21, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, + 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, + 0x61, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x21, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, + 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, + 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0xaa, 0x02, 0x1d, 0x56, 0x32, 0x52, 0x61, 0x79, + 0x2e, 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x72, + 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_app_router_command_command_proto_rawDescOnce sync.Once + file_app_router_command_command_proto_rawDescData = file_app_router_command_command_proto_rawDesc +) + +func file_app_router_command_command_proto_rawDescGZIP() []byte { + file_app_router_command_command_proto_rawDescOnce.Do(func() { + file_app_router_command_command_proto_rawDescData = protoimpl.X.CompressGZIP(file_app_router_command_command_proto_rawDescData) + }) + return file_app_router_command_command_proto_rawDescData +} + +var file_app_router_command_command_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_app_router_command_command_proto_goTypes = []interface{}{ + (*RoutingContext)(nil), // 0: v2ray.core.app.router.command.RoutingContext + (*SubscribeRoutingStatsRequest)(nil), // 1: v2ray.core.app.router.command.SubscribeRoutingStatsRequest + (*TestRouteRequest)(nil), // 2: v2ray.core.app.router.command.TestRouteRequest + (*Config)(nil), // 3: v2ray.core.app.router.command.Config + nil, // 4: v2ray.core.app.router.command.RoutingContext.AttributesEntry + (net.Network)(0), // 5: v2ray.core.common.net.Network +} +var file_app_router_command_command_proto_depIdxs = []int32{ + 5, // 0: v2ray.core.app.router.command.RoutingContext.Network:type_name -> v2ray.core.common.net.Network + 4, // 1: v2ray.core.app.router.command.RoutingContext.Attributes:type_name -> v2ray.core.app.router.command.RoutingContext.AttributesEntry + 0, // 2: v2ray.core.app.router.command.TestRouteRequest.RoutingContext:type_name -> v2ray.core.app.router.command.RoutingContext + 1, // 3: v2ray.core.app.router.command.RoutingService.SubscribeRoutingStats:input_type -> v2ray.core.app.router.command.SubscribeRoutingStatsRequest + 2, // 4: v2ray.core.app.router.command.RoutingService.TestRoute:input_type -> v2ray.core.app.router.command.TestRouteRequest + 0, // 5: v2ray.core.app.router.command.RoutingService.SubscribeRoutingStats:output_type -> v2ray.core.app.router.command.RoutingContext + 0, // 6: v2ray.core.app.router.command.RoutingService.TestRoute:output_type -> v2ray.core.app.router.command.RoutingContext + 5, // [5:7] is the sub-list for method output_type + 3, // [3:5] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_app_router_command_command_proto_init() } +func file_app_router_command_command_proto_init() { + if File_app_router_command_command_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_app_router_command_command_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RoutingContext); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_app_router_command_command_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SubscribeRoutingStatsRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_app_router_command_command_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TestRouteRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_app_router_command_command_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Config); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_app_router_command_command_proto_rawDesc, + NumEnums: 0, + NumMessages: 5, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_app_router_command_command_proto_goTypes, + DependencyIndexes: file_app_router_command_command_proto_depIdxs, + MessageInfos: file_app_router_command_command_proto_msgTypes, + }.Build() + File_app_router_command_command_proto = out.File + file_app_router_command_command_proto_rawDesc = nil + file_app_router_command_command_proto_goTypes = nil + file_app_router_command_command_proto_depIdxs = nil +} diff --git a/app/router/command/command.proto b/app/router/command/command.proto new file mode 100644 index 00000000..84210a2e --- /dev/null +++ b/app/router/command/command.proto @@ -0,0 +1,59 @@ +syntax = "proto3"; + +package v2ray.core.app.router.command; +option csharp_namespace = "V2Ray.Core.App.Router.Command"; +option go_package = "v2ray.com/core/app/router/command"; +option java_package = "com.v2ray.core.app.router.command"; +option java_multiple_files = true; + +import "common/net/network.proto"; + +// RoutingContext is the context with information relative to routing process. +// It conforms to the structure of v2ray.core.features.routing.Context and v2ray.core.features.routing.Route. +message RoutingContext { + string InboundTag = 1; + v2ray.core.common.net.Network Network = 2; + repeated bytes SourceIPs = 3; + repeated bytes TargetIPs = 4; + uint32 SourcePort = 5; + uint32 TargetPort = 6; + string TargetDomain = 7; + string Protocol = 8; + string User = 9; + map Attributes = 10; + repeated string OutboundGroupTags = 11; + string OutboundTag = 12; +} + +// SubscribeRoutingStatsRequest subscribes to routing statistics channel if opened by v2ray-core. +// * FieldSelectors selects a subset of fields in routing statistics to return. Valid selectors: +// - inbound: Selects connection's inbound tag. +// - network: Selects connection's network. +// - ip: Equivalent as "ip_source" and "ip_target", selects both source and target IP. +// - port: Equivalent as "port_source" and "port_target", selects both source and target port. +// - domain: Selects target domain. +// - protocol: Select connection's protocol. +// - user: Select connection's inbound user email. +// - attributes: Select connection's additional attributes. +// - outbound: Equivalent as "outbound" and "outbound_group", select both outbound tag and outbound group tags. +// * If FieldSelectors is left empty, all fields will be returned. +message SubscribeRoutingStatsRequest { + repeated string FieldSelectors = 1; +} + +// TestRouteRequest manually tests a routing result according to the routing context message. +// * RoutingContext is the routing message without outbound information. +// * FieldSelectors selects the fields to return in the routing result. All fields are returned if left empty. +// * PublishResult broadcasts the routing result to routing statistics channel if set true. +message TestRouteRequest { + RoutingContext RoutingContext = 1; + repeated string FieldSelectors = 2; + bool PublishResult = 3; +} + +service RoutingService { + rpc SubscribeRoutingStats(SubscribeRoutingStatsRequest) returns (stream RoutingContext) {} + rpc TestRoute(TestRouteRequest) returns (RoutingContext) {} +} + +message Config {} diff --git a/app/router/command/command_grpc.pb.go b/app/router/command/command_grpc.pb.go new file mode 100644 index 00000000..7b51b2cc --- /dev/null +++ b/app/router/command/command_grpc.pb.go @@ -0,0 +1,154 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package command + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion6 + +// RoutingServiceClient is the client API for RoutingService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type RoutingServiceClient interface { + SubscribeRoutingStats(ctx context.Context, in *SubscribeRoutingStatsRequest, opts ...grpc.CallOption) (RoutingService_SubscribeRoutingStatsClient, error) + TestRoute(ctx context.Context, in *TestRouteRequest, opts ...grpc.CallOption) (*RoutingContext, error) +} + +type routingServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewRoutingServiceClient(cc grpc.ClientConnInterface) RoutingServiceClient { + return &routingServiceClient{cc} +} + +func (c *routingServiceClient) SubscribeRoutingStats(ctx context.Context, in *SubscribeRoutingStatsRequest, opts ...grpc.CallOption) (RoutingService_SubscribeRoutingStatsClient, error) { + stream, err := c.cc.NewStream(ctx, &_RoutingService_serviceDesc.Streams[0], "/v2ray.core.app.router.command.RoutingService/SubscribeRoutingStats", opts...) + if err != nil { + return nil, err + } + x := &routingServiceSubscribeRoutingStatsClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type RoutingService_SubscribeRoutingStatsClient interface { + Recv() (*RoutingContext, error) + grpc.ClientStream +} + +type routingServiceSubscribeRoutingStatsClient struct { + grpc.ClientStream +} + +func (x *routingServiceSubscribeRoutingStatsClient) Recv() (*RoutingContext, error) { + m := new(RoutingContext) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *routingServiceClient) TestRoute(ctx context.Context, in *TestRouteRequest, opts ...grpc.CallOption) (*RoutingContext, error) { + out := new(RoutingContext) + err := c.cc.Invoke(ctx, "/v2ray.core.app.router.command.RoutingService/TestRoute", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// RoutingServiceServer is the server API for RoutingService service. +// All implementations must embed UnimplementedRoutingServiceServer +// for forward compatibility +type RoutingServiceServer interface { + SubscribeRoutingStats(*SubscribeRoutingStatsRequest, RoutingService_SubscribeRoutingStatsServer) error + TestRoute(context.Context, *TestRouteRequest) (*RoutingContext, error) + mustEmbedUnimplementedRoutingServiceServer() +} + +// UnimplementedRoutingServiceServer must be embedded to have forward compatible implementations. +type UnimplementedRoutingServiceServer struct { +} + +func (*UnimplementedRoutingServiceServer) SubscribeRoutingStats(*SubscribeRoutingStatsRequest, RoutingService_SubscribeRoutingStatsServer) error { + return status.Errorf(codes.Unimplemented, "method SubscribeRoutingStats not implemented") +} +func (*UnimplementedRoutingServiceServer) TestRoute(context.Context, *TestRouteRequest) (*RoutingContext, error) { + return nil, status.Errorf(codes.Unimplemented, "method TestRoute not implemented") +} +func (*UnimplementedRoutingServiceServer) mustEmbedUnimplementedRoutingServiceServer() {} + +func RegisterRoutingServiceServer(s *grpc.Server, srv RoutingServiceServer) { + s.RegisterService(&_RoutingService_serviceDesc, srv) +} + +func _RoutingService_SubscribeRoutingStats_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(SubscribeRoutingStatsRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(RoutingServiceServer).SubscribeRoutingStats(m, &routingServiceSubscribeRoutingStatsServer{stream}) +} + +type RoutingService_SubscribeRoutingStatsServer interface { + Send(*RoutingContext) error + grpc.ServerStream +} + +type routingServiceSubscribeRoutingStatsServer struct { + grpc.ServerStream +} + +func (x *routingServiceSubscribeRoutingStatsServer) Send(m *RoutingContext) error { + return x.ServerStream.SendMsg(m) +} + +func _RoutingService_TestRoute_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TestRouteRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(RoutingServiceServer).TestRoute(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/v2ray.core.app.router.command.RoutingService/TestRoute", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(RoutingServiceServer).TestRoute(ctx, req.(*TestRouteRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _RoutingService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "v2ray.core.app.router.command.RoutingService", + HandlerType: (*RoutingServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "TestRoute", + Handler: _RoutingService_TestRoute_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "SubscribeRoutingStats", + Handler: _RoutingService_SubscribeRoutingStats_Handler, + ServerStreams: true, + }, + }, + Metadata: "app/router/command/command.proto", +} diff --git a/app/router/command/command_test.go b/app/router/command/command_test.go new file mode 100644 index 00000000..d9fcf585 --- /dev/null +++ b/app/router/command/command_test.go @@ -0,0 +1,334 @@ +package command_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" + "v2ray.com/core/app/router" + . "v2ray.com/core/app/router/command" + "v2ray.com/core/app/stats" + "v2ray.com/core/common" + "v2ray.com/core/common/net" + "v2ray.com/core/features/routing" + "v2ray.com/core/testing/mocks" +) + +func TestServiceSubscribeRoutingStats(t *testing.T) { + c := stats.NewChannel(&stats.ChannelConfig{ + SubscriberLimit: 1, + BufferSize: 16, + BroadcastTimeout: 100, + }) + common.Must(c.Start()) + defer c.Close() + + lis := bufconn.Listen(1024 * 1024) + bufDialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + testCases := []*RoutingContext{ + {InboundTag: "in", OutboundTag: "out"}, + {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"}, + {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"}, + {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"}, + {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"}, + {Protocol: "bittorrent", OutboundTag: "blocked"}, + {User: "example@v2fly.org", OutboundTag: "out"}, + {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"}, + } + errCh := make(chan error) + nextPub := make(chan struct{}) + + // Server goroutine + go func() { + server := grpc.NewServer() + RegisterRoutingServiceServer(server, NewRoutingServer(nil, c)) + errCh <- server.Serve(lis) + }() + + // Publisher goroutine + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for { // Wait until there's one subscriber in routing stats channel + if len(c.Subscribers()) > 0 { + break + } + if ctx.Err() != nil { + errCh <- ctx.Err() + } + } + for _, tc := range testCases { + c.Publish(AsRoutingRoute(tc)) + } + + // Wait for next round of publishing + <-nextPub + + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + for { // Wait until there's one subscriber in routing stats channel + if len(c.Subscribers()) > 0 { + break + } + if ctx.Err() != nil { + errCh <- ctx.Err() + } + } + for _, tc := range testCases { + c.Publish(AsRoutingRoute(tc)) + } + }() + + // Client goroutine + go func() { + conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + errCh <- err + } + defer lis.Close() + defer conn.Close() + client := NewRoutingServiceClient(conn) + + // Test retrieving all fields + streamCtx, streamClose := context.WithCancel(context.Background()) + stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{}) + if err != nil { + errCh <- err + } + + for _, tc := range testCases { + msg, err := stream.Recv() + if err != nil { + errCh <- err + } + if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + } + + // Test that double subscription will fail + errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{ + FieldSelectors: []string{"ip", "port", "domain", "outbound"}, + }) + if err != nil { + errCh <- err + } + if _, err := errStream.Recv(); err == nil { + t.Error("unexpected successful subscription") + } + + // Test the unsubscription of stream works well + streamClose() + timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second) + defer timeout() + for { // Wait until there's no subscriber in routing stats channel + if len(c.Subscribers()) == 0 { + break + } + if timeOutCtx.Err() != nil { + t.Error("unexpected subscribers not decreased in channel") + errCh <- timeOutCtx.Err() + } + } + + // Test retrieving only a subset of fields + streamCtx, streamClose = context.WithCancel(context.Background()) + stream, err = client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{ + FieldSelectors: []string{"ip", "port", "domain", "outbound"}, + }) + if err != nil { + errCh <- err + } + + close(nextPub) // Send nextPub signal to start next round of publishing + for _, tc := range testCases { + msg, err := stream.Recv() + stat := &RoutingContext{ // Only a subset of stats is retrieved + SourceIPs: tc.SourceIPs, + TargetIPs: tc.TargetIPs, + SourcePort: tc.SourcePort, + TargetPort: tc.TargetPort, + TargetDomain: tc.TargetDomain, + OutboundGroupTags: tc.OutboundGroupTags, + OutboundTag: tc.OutboundTag, + } + if err != nil { + errCh <- err + } + if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + } + streamClose() + + // Client passed all tests successfully + errCh <- nil + }() + + // Wait for goroutines to complete + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } +} + +func TestSerivceTestRoute(t *testing.T) { + c := stats.NewChannel(&stats.ChannelConfig{ + SubscriberLimit: 1, + BufferSize: 16, + BroadcastTimeout: 100, + }) + common.Must(c.Start()) + defer c.Close() + + r := new(router.Router) + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + common.Must(r.Init(&router.Config{ + Rule: []*router.RoutingRule{ + { + InboundTag: []string{"in"}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + }, + { + Protocol: []string{"bittorrent"}, + TargetTag: &router.RoutingRule_Tag{Tag: "blocked"}, + }, + { + PortList: &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + }, + { + SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + }, + { + Domain: []*router.Domain{{Type: router.Domain_Domain, Value: "com"}}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + }, + { + SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + }, + { + UserEmail: []string{"example@v2fly.org"}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + }, + { + Networks: []net.Network{net.Network_UDP, net.Network_TCP}, + TargetTag: &router.RoutingRule_Tag{Tag: "out"}, + }, + }, + }, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl))) + + lis := bufconn.Listen(1024 * 1024) + bufDialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + errCh := make(chan error) + + // Server goroutine + go func() { + server := grpc.NewServer() + RegisterRoutingServiceServer(server, NewRoutingServer(r, c)) + errCh <- server.Serve(lis) + }() + + // Client goroutine + go func() { + conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + errCh <- err + } + defer lis.Close() + defer conn.Close() + client := NewRoutingServiceClient(conn) + + testCases := []*RoutingContext{ + {InboundTag: "in", OutboundTag: "out"}, + {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"}, + {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"}, + {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"}, + {Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"}, + {User: "example@v2fly.org", OutboundTag: "out"}, + {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"}, + } + + // Test simple TestRoute + for _, tc := range testCases { + route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc}) + if err != nil { + errCh <- err + } + if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + } + + // Test TestRoute with special options + sub, err := c.Subscribe() + if err != nil { + errCh <- err + } + for _, tc := range testCases { + route, err := client.TestRoute(context.Background(), &TestRouteRequest{ + RoutingContext: tc, + FieldSelectors: []string{"ip", "port", "domain", "outbound"}, + PublishResult: true, + }) + stat := &RoutingContext{ // Only a subset of stats is retrieved + SourceIPs: tc.SourceIPs, + TargetIPs: tc.TargetIPs, + SourcePort: tc.SourcePort, + TargetPort: tc.TargetPort, + TargetDomain: tc.TargetDomain, + OutboundGroupTags: tc.OutboundGroupTags, + OutboundTag: tc.OutboundTag, + } + if err != nil { + errCh <- err + } + if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + select { // Check that routing result has been published to statistics channel + case msg, received := <-sub: + if route, ok := msg.(routing.Route); received && ok { + if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + } else { + t.Error("unexpected failure in receiving published routing result") + } + case <-time.After(100 * time.Millisecond): + t.Error("unexpected failure in receiving published routing result") + } + } + + // Client passed all tests successfully + errCh <- nil + }() + + // Wait for goroutines to complete + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } +} diff --git a/app/router/command/config.go b/app/router/command/config.go new file mode 100644 index 00000000..1385f296 --- /dev/null +++ b/app/router/command/config.go @@ -0,0 +1,94 @@ +package command + +import ( + "strings" + + "v2ray.com/core/common/net" + "v2ray.com/core/features/routing" +) + +// routingContext is an wrapper of protobuf RoutingContext as implementation of routing.Context and routing.Route. +type routingContext struct { + *RoutingContext +} + +func (c routingContext) GetSourceIPs() []net.IP { + return mapBytesToIPs(c.RoutingContext.GetSourceIPs()) +} + +func (c routingContext) GetSourcePort() net.Port { + return net.Port(c.RoutingContext.GetSourcePort()) +} + +func (c routingContext) GetTargetIPs() []net.IP { + return mapBytesToIPs(c.RoutingContext.GetTargetIPs()) +} + +func (c routingContext) GetTargetPort() net.Port { + return net.Port(c.RoutingContext.GetTargetPort()) +} + +// AsRoutingContext converts a protobuf RoutingContext into an implementation of routing.Context. +func AsRoutingContext(r *RoutingContext) routing.Context { + return routingContext{r} +} + +// AsRoutingRoute converts a protobuf RoutingContext into an implementation of routing.Route. +func AsRoutingRoute(r *RoutingContext) routing.Route { + return routingContext{r} +} + +var fieldMap = map[string]func(*RoutingContext, routing.Route){ + "inbound": func(s *RoutingContext, r routing.Route) { s.InboundTag = r.GetInboundTag() }, + "network": func(s *RoutingContext, r routing.Route) { s.Network = r.GetNetwork() }, + "ip_source": func(s *RoutingContext, r routing.Route) { s.SourceIPs = mapIPsToBytes(r.GetSourceIPs()) }, + "ip_target": func(s *RoutingContext, r routing.Route) { s.TargetIPs = mapIPsToBytes(r.GetTargetIPs()) }, + "port_source": func(s *RoutingContext, r routing.Route) { s.SourcePort = uint32(r.GetSourcePort()) }, + "port_target": func(s *RoutingContext, r routing.Route) { s.TargetPort = uint32(r.GetTargetPort()) }, + "domain": func(s *RoutingContext, r routing.Route) { s.TargetDomain = r.GetTargetDomain() }, + "protocol": func(s *RoutingContext, r routing.Route) { s.Protocol = r.GetProtocol() }, + "user": func(s *RoutingContext, r routing.Route) { s.User = r.GetUser() }, + "attributes": func(s *RoutingContext, r routing.Route) { s.Attributes = r.GetAttributes() }, + "outbound_group": func(s *RoutingContext, r routing.Route) { s.OutboundGroupTags = r.GetOutboundGroupTags() }, + "outbound": func(s *RoutingContext, r routing.Route) { s.OutboundTag = r.GetOutboundTag() }, +} + +// AsProtobufMessage takes selectors of fields and returns a function to convert routing.Route to protobuf RoutingContext. +func AsProtobufMessage(fieldSelectors []string) func(routing.Route) *RoutingContext { + initializers := []func(*RoutingContext, routing.Route){} + for field, init := range fieldMap { + if len(fieldSelectors) == 0 { // If selectors not set, retrieve all fields + initializers = append(initializers, init) + continue + } + for _, selector := range fieldSelectors { + if strings.HasPrefix(field, selector) { + initializers = append(initializers, init) + break + } + } + } + return func(ctx routing.Route) *RoutingContext { + message := new(RoutingContext) + for _, init := range initializers { + init(message, ctx) + } + return message + } +} + +func mapBytesToIPs(bytes [][]byte) []net.IP { + var ips []net.IP + for _, rawIP := range bytes { + ips = append(ips, net.IP(rawIP)) + } + return ips +} + +func mapIPsToBytes(ips []net.IP) [][]byte { + var bytes [][]byte + for _, ip := range ips { + bytes = append(bytes, []byte(ip)) + } + return bytes +} diff --git a/app/router/command/errors.generated.go b/app/router/command/errors.generated.go new file mode 100644 index 00000000..66f78051 --- /dev/null +++ b/app/router/command/errors.generated.go @@ -0,0 +1,9 @@ +package command + +import "v2ray.com/core/common/errors" + +type errPathObjHolder struct{} + +func newError(values ...interface{}) *errors.Error { + return errors.New(values...).WithPathObj(errPathObjHolder{}) +} diff --git a/app/router/router.go b/app/router/router.go index 7e04c554..de8321ba 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -9,24 +9,12 @@ import ( "v2ray.com/core" "v2ray.com/core/common" - "v2ray.com/core/common/net" "v2ray.com/core/features/dns" "v2ray.com/core/features/outbound" "v2ray.com/core/features/routing" + routing_dns "v2ray.com/core/features/routing/dns" ) -func init() { - common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { - r := new(Router) - if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager) error { - return r.Init(config.(*Config), d, ohm) - }); err != nil { - return nil, err - } - return r, nil - })) -} - // Router is an implementation of routing.Router. type Router struct { domainStrategy Config_DomainStrategy @@ -35,6 +23,13 @@ type Router struct { dns dns.Client } +// Route is an implementation of routing.Route. +type Route struct { + routing.Context + outboundGroupTags []string + outboundTag string +} + // Init initializes the Router. func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error { r.domainStrategy = config.DomainStrategy @@ -74,39 +69,43 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error } // PickRoute implements routing.Router. -func (r *Router) PickRoute(ctx routing.Context) (string, error) { - rule, err := r.pickRouteInternal(ctx) +func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) { + rule, ctx, err := r.pickRouteInternal(ctx) if err != nil { - return "", err + return nil, err } - return rule.GetTag() + tag, err := rule.GetTag() + if err != nil { + return nil, err + } + return &Route{Context: ctx, outboundTag: tag}, nil } -func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, error) { +func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) { if r.domainStrategy == Config_IpOnDemand { - ctx = ContextWithDNSClient(ctx, r.dns) + ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) } for _, rule := range r.rules { if rule.Apply(ctx) { - return rule, nil + return rule, ctx, nil } } if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 { - return nil, common.ErrNoClue + return nil, ctx, common.ErrNoClue } - ctx = ContextWithDNSClient(ctx, r.dns) + ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) // Try applying rules again if we have IPs. for _, rule := range r.rules { if rule.Apply(ctx) { - return rule, nil + return rule, ctx, nil } } - return nil, common.ErrNoClue + return nil, ctx, common.ErrNoClue } // Start implements common.Runnable. @@ -124,34 +123,24 @@ func (*Router) Type() interface{} { return routing.RouterType() } -// ContextWithDNSClient creates a new routing context with domain resolving capability. Resolved domain IPs can be retrieved by GetTargetIPs(). -func ContextWithDNSClient(ctx routing.Context, client dns.Client) routing.Context { - return &resolvableContext{Context: ctx, dnsClient: client} +// GetOutboundGroupTags implements routing.Route. +func (r *Route) GetOutboundGroupTags() []string { + return r.outboundGroupTags } -type resolvableContext struct { - routing.Context - dnsClient dns.Client - resolvedIPs []net.IP +// GetOutboundTag implements routing.Route. +func (r *Route) GetOutboundTag() string { + return r.outboundTag } -func (ctx *resolvableContext) GetTargetIPs() []net.IP { - if ips := ctx.Context.GetTargetIPs(); len(ips) != 0 { - return ips - } - - if len(ctx.resolvedIPs) > 0 { - return ctx.resolvedIPs - } - - if domain := ctx.GetTargetDomain(); len(domain) != 0 { - ips, err := ctx.dnsClient.LookupIP(domain) - if err == nil { - ctx.resolvedIPs = ips - return ips +func init() { + common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { + r := new(Router) + if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager) error { + return r.Init(config.(*Config), d, ohm) + }); err != nil { + return nil, err } - newError("resolve ip for ", domain).Base(err).WriteToLog() - } - - return nil + return r, nil + })) } diff --git a/app/router/router_test.go b/app/router/router_test.go index 0ed5f033..8c1aec0a 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -45,9 +45,9 @@ func TestSimpleRouter(t *testing.T) { })) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) + route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) - if tag != "test" { + if tag := route.GetOutboundTag(); tag != "test" { t.Error("expect tag 'test', bug actually ", tag) } } @@ -86,9 +86,9 @@ func TestSimpleBalancer(t *testing.T) { })) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) + route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) - if tag != "test" { + if tag := route.GetOutboundTag(); tag != "test" { t.Error("expect tag 'test', bug actually ", tag) } } @@ -121,9 +121,9 @@ func TestIPOnDemand(t *testing.T) { common.Must(r.Init(config, mockDns, nil)) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) + route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) - if tag != "test" { + if tag := route.GetOutboundTag(); tag != "test" { t.Error("expect tag 'test', bug actually ", tag) } } @@ -156,9 +156,9 @@ func TestIPIfNonMatchDomain(t *testing.T) { common.Must(r.Init(config, mockDns, nil)) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) + route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) - if tag != "test" { + if tag := route.GetOutboundTag(); tag != "test" { t.Error("expect tag 'test', bug actually ", tag) } } @@ -190,9 +190,9 @@ func TestIPIfNonMatchIP(t *testing.T) { common.Must(r.Init(config, mockDns, nil)) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) - tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) + route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) - if tag != "test" { + if tag := route.GetOutboundTag(); tag != "test" { t.Error("expect tag 'test', bug actually ", tag) } } diff --git a/app/stats/channel.go b/app/stats/channel.go new file mode 100644 index 00000000..dd484fab --- /dev/null +++ b/app/stats/channel.go @@ -0,0 +1,144 @@ +// +build !confonly + +package stats + +import ( + "sync" + "time" + + "v2ray.com/core/common" +) + +// Channel is an implementation of stats.Channel. +type Channel struct { + channel chan interface{} + subscribers []chan interface{} + + // Synchronization components + access sync.RWMutex + closed chan struct{} + + // Channel options + subscriberLimit int // Set to 0 as no subscriber limit + channelBufferSize int // Set to 0 as no buffering + broadcastTimeout time.Duration // Set to 0 as non-blocking immediate timeout +} + +// NewChannel creates an instance of Statistics Channel. +func NewChannel(config *ChannelConfig) *Channel { + return &Channel{ + channel: make(chan interface{}, config.BufferSize), + subscriberLimit: int(config.SubscriberLimit), + channelBufferSize: int(config.BufferSize), + broadcastTimeout: time.Duration(config.BroadcastTimeout+1) * time.Millisecond, + } +} + +// Channel returns the underlying go channel. +func (c *Channel) Channel() chan interface{} { + c.access.RLock() + defer c.access.RUnlock() + return c.channel +} + +// Subscribers implements stats.Channel. +func (c *Channel) Subscribers() []chan interface{} { + c.access.RLock() + defer c.access.RUnlock() + return c.subscribers +} + +// Subscribe implements stats.Channel. +func (c *Channel) Subscribe() (chan interface{}, error) { + c.access.Lock() + defer c.access.Unlock() + if c.subscriberLimit > 0 && len(c.subscribers) >= c.subscriberLimit { + return nil, newError("Number of subscribers has reached limit") + } + subscriber := make(chan interface{}, c.channelBufferSize) + c.subscribers = append(c.subscribers, subscriber) + return subscriber, nil +} + +// Unsubscribe implements stats.Channel. +func (c *Channel) Unsubscribe(subscriber chan interface{}) error { + c.access.Lock() + defer c.access.Unlock() + for i, s := range c.subscribers { + if s == subscriber { + // Copy to new memory block to prevent modifying original data + subscribers := make([]chan interface{}, len(c.subscribers)-1) + copy(subscribers[:i], c.subscribers[:i]) + copy(subscribers[i:], c.subscribers[i+1:]) + c.subscribers = subscribers + } + } + return nil +} + +// Publish implements stats.Channel. +func (c *Channel) Publish(message interface{}) { + select { // Early exit if channel closed + case <-c.closed: + return + default: + } + select { // Drop message if not successfully sent + case c.channel <- message: + default: + return + } +} + +// Running returns whether the channel is running. +func (c *Channel) Running() bool { + select { + case <-c.closed: // Channel closed + default: // Channel running or not initialized + if c.closed != nil { // Channel initialized + return true + } + } + return false +} + +// Start implements common.Runnable. +func (c *Channel) Start() error { + c.access.Lock() + defer c.access.Unlock() + if !c.Running() { + c.closed = make(chan struct{}) // Reset close signal + go func() { + for { + select { + case message := <-c.channel: // Broadcast message + for _, sub := range c.Subscribers() { // Concurrency-safe subscribers retreivement + select { + case sub <- message: // Successfully sent message + case <-time.After(c.broadcastTimeout): // Remove timeout subscriber + common.Must(c.Unsubscribe(sub)) + close(sub) // Actively close subscriber as notification + } + } + case <-c.closed: // Channel closed + for _, sub := range c.Subscribers() { // Remove all subscribers + common.Must(c.Unsubscribe(sub)) + close(sub) + } + return + } + } + }() + } + return nil +} + +// Close implements common.Closable. +func (c *Channel) Close() error { + c.access.Lock() + defer c.access.Unlock() + if c.Running() { + close(c.closed) // Send closed signal + } + return nil +} diff --git a/app/stats/channel_test.go b/app/stats/channel_test.go new file mode 100644 index 00000000..6458711b --- /dev/null +++ b/app/stats/channel_test.go @@ -0,0 +1,350 @@ +package stats_test + +import ( + "fmt" + "testing" + "time" + + . "v2ray.com/core/app/stats" + "v2ray.com/core/common" + "v2ray.com/core/features/stats" +) + +func TestStatsChannel(t *testing.T) { + // At most 2 subscribers could be registered + c := NewChannel(&ChannelConfig{SubscriberLimit: 2}) + source := c.Channel() + + a, err := stats.SubscribeRunnableChannel(c) + common.Must(err) + if !c.Running() { + t.Fatal("unexpected failure in running channel after first subscription") + } + + b, err := c.Subscribe() + common.Must(err) + + // Test that third subscriber is forbidden + _, err = c.Subscribe() + if err == nil { + t.Fatal("unexpected successful subscription") + } + t.Log("expected error: ", err) + + stopCh := make(chan struct{}) + errCh := make(chan string) + + go func() { // Blocking publish + source <- 1 + source <- 2 + source <- "3" + source <- []int{4} + source <- nil // Dummy messsage with no subscriber receiving, will block reading goroutine + for i := 0; i < cap(source); i++ { + source <- nil // Fill source channel's buffer + } + select { + case source <- nil: // Source writing should be blocked here, for last message was not cleared and buffer was full + errCh <- fmt.Sprint("unexpected non-blocked source channel") + default: + close(stopCh) + } + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + if v, ok := (<-a).(string); !ok || v != "3" { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", "3") + } + if v, ok := (<-a).([]int); !ok || v[0] != 4 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) + } + }() + + go func() { + if v, ok := (<-b).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-b).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + if v, ok := (<-b).(string); !ok || v != "3" { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", "3") + } + if v, ok := (<-b).([]int); !ok || v[0] != 4 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) + } + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } + + // Test the unsubscription of channel + common.Must(c.Unsubscribe(b)) + + // Test the last subscriber will close channel with `UnsubscribeClosableChannel` + common.Must(stats.UnsubscribeClosableChannel(c, a)) + if c.Running() { + t.Fatal("unexpected running channel after unsubscribing the last subscriber") + } +} + +func TestStatsChannelUnsubcribe(t *testing.T) { + c := NewChannel(&ChannelConfig{}) + common.Must(c.Start()) + defer c.Close() + + source := c.Channel() + + a, err := c.Subscribe() + common.Must(err) + defer c.Unsubscribe(a) + + b, err := c.Subscribe() + common.Must(err) + + pauseCh := make(chan struct{}) + stopCh := make(chan struct{}) + errCh := make(chan string) + + { + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && bSet) { + t.Fatal("unexpected subscribers: ", c.Subscribers()) + } + } + + go func() { // Blocking publish + source <- 1 + <-pauseCh // Wait for `b` goroutine to resume sending message + source <- 2 + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + }() + + go func() { + if v, ok := (<-b).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + // Unsubscribe `b` while `source`'s messaging is paused + c.Unsubscribe(b) + { // Test `b` is not in subscribers + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && !bSet) { + errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) + } + } + // Resume `source`'s progress + close(pauseCh) + // Test `b` is neither closed nor able to receive any data + select { + case v, ok := <-b: + if ok { + errCh <- fmt.Sprint("unexpected data received: ", v) + } else { + errCh <- fmt.Sprint("unexpected closed channel: ", b) + } + default: + } + close(stopCh) + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} + +func TestStatsChannelTimeout(t *testing.T) { + // Do not use buffer so as to create blocking scenario + c := NewChannel(&ChannelConfig{BufferSize: 0, BroadcastTimeout: 50}) + common.Must(c.Start()) + defer c.Close() + + source := c.Channel() + + a, err := c.Subscribe() + common.Must(err) + defer c.Unsubscribe(a) + + b, err := c.Subscribe() + common.Must(err) + defer c.Unsubscribe(b) + + stopCh := make(chan struct{}) + errCh := make(chan string) + + go func() { // Blocking publish + source <- 1 + source <- 2 + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + { // Test `b` is still in subscribers yet (because `a` receives 2 first) + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && bSet) { + errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) + } + } + }() + + go func() { + if v, ok := (<-b).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + // Block `b` channel for a time longer than `source`'s timeout + <-time.After(200 * time.Millisecond) + { // Test `b` has been unsubscribed by source + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && !bSet) { + errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) + } + } + select { // Test `b` has been closed by source + case v, ok := <-b: + if ok { + errCh <- fmt.Sprint("unexpected data received: ", v) + } + default: + } + close(stopCh) + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} + +func TestStatsChannelConcurrency(t *testing.T) { + // Do not use buffer so as to create blocking scenario + c := NewChannel(&ChannelConfig{BufferSize: 0, BroadcastTimeout: 100}) + common.Must(c.Start()) + defer c.Close() + + source := c.Channel() + + a, err := c.Subscribe() + common.Must(err) + defer c.Unsubscribe(a) + + b, err := c.Subscribe() + common.Must(err) + defer c.Unsubscribe(b) + + stopCh := make(chan struct{}) + errCh := make(chan string) + + go func() { // Blocking publish + source <- 1 + source <- 2 + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + }() + + go func() { + // Block `b` for a time shorter than `source`'s timeout + // So as to ensure source channel is trying to send message to `b`. + <-time.After(25 * time.Millisecond) + // This causes concurrency scenario: unsubscribe `b` while trying to send message to it + c.Unsubscribe(b) + // Test `b` is not closed and can still receive data 1: + // Because unsubscribe won't affect the ongoing process of sending message. + select { + case v, ok := <-b: + if v1, ok1 := v.(int); !(ok && ok1 && v1 == 1) { + errCh <- fmt.Sprint("unexpected failure in receiving data: ", 1) + } + default: + errCh <- fmt.Sprint("unexpected block from receiving data: ", 1) + } + // Test `b` is not closed but cannot receive data 2: + // Becuase in a new round of messaging, `b` has been unsubscribed. + select { + case v, ok := <-b: + if ok { + errCh <- fmt.Sprint("unexpected receving: ", v) + } else { + errCh <- fmt.Sprint("unexpected closing of channel") + } + default: + } + close(stopCh) + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} diff --git a/app/stats/config.go b/app/stats/config.go deleted file mode 100644 index e124b17a..00000000 --- a/app/stats/config.go +++ /dev/null @@ -1,15 +0,0 @@ -// +build !confonly - -package stats - -import ( - "context" - - "v2ray.com/core/common" -) - -func init() { - common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { - return NewManager(ctx, config.(*Config)) - })) -} diff --git a/app/stats/config.pb.go b/app/stats/config.pb.go index f430b641..f9402fc7 100644 --- a/app/stats/config.pb.go +++ b/app/stats/config.pb.go @@ -63,18 +63,90 @@ func (*Config) Descriptor() ([]byte, []int) { return file_app_stats_config_proto_rawDescGZIP(), []int{0} } +type ChannelConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SubscriberLimit int32 `protobuf:"varint,1,opt,name=SubscriberLimit,proto3" json:"SubscriberLimit,omitempty"` + BufferSize int32 `protobuf:"varint,2,opt,name=BufferSize,proto3" json:"BufferSize,omitempty"` + BroadcastTimeout int32 `protobuf:"varint,3,opt,name=BroadcastTimeout,proto3" json:"BroadcastTimeout,omitempty"` +} + +func (x *ChannelConfig) Reset() { + *x = ChannelConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_app_stats_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ChannelConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChannelConfig) ProtoMessage() {} + +func (x *ChannelConfig) ProtoReflect() protoreflect.Message { + mi := &file_app_stats_config_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChannelConfig.ProtoReflect.Descriptor instead. +func (*ChannelConfig) Descriptor() ([]byte, []int) { + return file_app_stats_config_proto_rawDescGZIP(), []int{1} +} + +func (x *ChannelConfig) GetSubscriberLimit() int32 { + if x != nil { + return x.SubscriberLimit + } + return 0 +} + +func (x *ChannelConfig) GetBufferSize() int32 { + if x != nil { + return x.BufferSize + } + return 0 +} + +func (x *ChannelConfig) GetBroadcastTimeout() int32 { + if x != nil { + return x.BroadcastTimeout + } + return 0 +} + var File_app_stats_config_proto protoreflect.FileDescriptor var file_app_stats_config_proto_rawDesc = []byte{ 0x0a, 0x16, 0x61, 0x70, 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x22, 0x08, - 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x42, 0x4d, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, - 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, - 0x74, 0x61, 0x74, 0x73, 0x50, 0x01, 0x5a, 0x18, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, - 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e, 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x41, 0x70, - 0x70, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x85, 0x01, 0x0a, 0x0d, 0x43, 0x68, 0x61, + 0x6e, 0x6e, 0x65, 0x6c, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x53, 0x75, + 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x0f, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c, + 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x53, 0x69, + 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, + 0x53, 0x69, 0x7a, 0x65, 0x12, 0x2a, 0x0a, 0x10, 0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73, + 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x10, + 0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, + 0x42, 0x4d, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, + 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x50, 0x01, 0x5a, 0x18, + 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, + 0x70, 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79, + 0x2e, 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x73, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -89,9 +161,10 @@ func file_app_stats_config_proto_rawDescGZIP() []byte { return file_app_stats_config_proto_rawDescData } -var file_app_stats_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_app_stats_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_app_stats_config_proto_goTypes = []interface{}{ - (*Config)(nil), // 0: v2ray.core.app.stats.Config + (*Config)(nil), // 0: v2ray.core.app.stats.Config + (*ChannelConfig)(nil), // 1: v2ray.core.app.stats.ChannelConfig } var file_app_stats_config_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type @@ -119,6 +192,18 @@ func file_app_stats_config_proto_init() { return nil } } + file_app_stats_config_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ChannelConfig); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -126,7 +211,7 @@ func file_app_stats_config_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_app_stats_config_proto_rawDesc, NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, diff --git a/app/stats/config.proto b/app/stats/config.proto index bcbf847f..0ea911fd 100644 --- a/app/stats/config.proto +++ b/app/stats/config.proto @@ -7,5 +7,11 @@ option java_package = "com.v2ray.core.app.stats"; option java_multiple_files = true; message Config { - + +} + +message ChannelConfig { + int32 SubscriberLimit = 1; + int32 BufferSize = 2; + int32 BroadcastTimeout = 3; } diff --git a/app/stats/counter.go b/app/stats/counter.go new file mode 100644 index 00000000..c4e12013 --- /dev/null +++ b/app/stats/counter.go @@ -0,0 +1,25 @@ +// +build !confonly + +package stats + +import "sync/atomic" + +// Counter is an implementation of stats.Counter. +type Counter struct { + value int64 +} + +// Value implements stats.Counter. +func (c *Counter) Value() int64 { + return atomic.LoadInt64(&c.value) +} + +// Set implements stats.Counter. +func (c *Counter) Set(newValue int64) int64 { + return atomic.SwapInt64(&c.value, newValue) +} + +// Add implements stats.Counter. +func (c *Counter) Add(delta int64) int64 { + return atomic.AddInt64(&c.value, delta) +} diff --git a/app/stats/counter_test.go b/app/stats/counter_test.go new file mode 100644 index 00000000..f2594e1e --- /dev/null +++ b/app/stats/counter_test.go @@ -0,0 +1,31 @@ +package stats_test + +import ( + "context" + "testing" + + . "v2ray.com/core/app/stats" + "v2ray.com/core/common" + "v2ray.com/core/features/stats" +) + +func TestStatsCounter(t *testing.T) { + raw, err := common.CreateObject(context.Background(), &Config{}) + common.Must(err) + + m := raw.(stats.Manager) + c, err := m.RegisterCounter("test.counter") + common.Must(err) + + if v := c.Add(1); v != 1 { + t.Fatal("unpexcted Add(1) return: ", v, ", wanted ", 1) + } + + if v := c.Set(0); v != 1 { + t.Fatal("unexpected Set(0) return: ", v, ", wanted ", 1) + } + + if v := c.Value(); v != 0 { + t.Fatal("unexpected Value() return: ", v, ", wanted ", 0) + } +} diff --git a/app/stats/stats.go b/app/stats/stats.go index be45a306..1156fcae 100644 --- a/app/stats/stats.go +++ b/app/stats/stats.go @@ -7,98 +7,21 @@ package stats import ( "context" "sync" - "sync/atomic" - "time" + "v2ray.com/core/common" + "v2ray.com/core/common/errors" "v2ray.com/core/features/stats" ) -// Counter is an implementation of stats.Counter. -type Counter struct { - value int64 -} - -// Value implements stats.Counter. -func (c *Counter) Value() int64 { - return atomic.LoadInt64(&c.value) -} - -// Set implements stats.Counter. -func (c *Counter) Set(newValue int64) int64 { - return atomic.SwapInt64(&c.value, newValue) -} - -// Add implements stats.Counter. -func (c *Counter) Add(delta int64) int64 { - return atomic.AddInt64(&c.value, delta) -} - -// Channel is an implementation of stats.Channel -type Channel struct { - channel chan interface{} - subscribers []chan interface{} - access sync.RWMutex -} - -// Channel implements stats.Channel -func (c *Channel) Channel() chan interface{} { - return c.channel -} - -// Subscribers implements stats.Channel -func (c *Channel) Subscribers() []chan interface{} { - c.access.RLock() - defer c.access.RUnlock() - return c.subscribers -} - -// Subscribe implements stats.Channel -func (c *Channel) Subscribe() chan interface{} { - c.access.Lock() - defer c.access.Unlock() - ch := make(chan interface{}) - c.subscribers = append(c.subscribers, ch) - return ch -} - -// Unsubscribe implements stats.Channel -func (c *Channel) Unsubscribe(ch chan interface{}) { - c.access.Lock() - defer c.access.Unlock() - for i, s := range c.subscribers { - if s == ch { - // Copy to new memory block to prevent modifying original data - subscribers := make([]chan interface{}, len(c.subscribers)-1) - copy(subscribers[:i], c.subscribers[:i]) - copy(subscribers[i:], c.subscribers[i+1:]) - c.subscribers = subscribers - return - } - } -} - -// Start starts the channel for listening to messsages -func (c *Channel) Start() { - for message := range c.Channel() { - subscribers := c.Subscribers() // Store a copy of slice value for concurrency safety - for _, sub := range subscribers { - select { - case sub <- message: // Successfully sent message - case <-time.After(100 * time.Millisecond): - c.Unsubscribe(sub) // Remove timeout subscriber - close(sub) // Actively close subscriber as notification - } - } - } -} - // Manager is an implementation of stats.Manager. type Manager struct { access sync.RWMutex counters map[string]*Counter channels map[string]*Channel + running bool } +// NewManager creates an instance of Statistics Manager. func NewManager(ctx context.Context, config *Config) (*Manager, error) { m := &Manager{ counters: make(map[string]*Counter), @@ -108,6 +31,7 @@ func NewManager(ctx context.Context, config *Config) (*Manager, error) { return m, nil } +// Type implements common.HasType. func (*Manager) Type() interface{} { return stats.ManagerType() } @@ -170,9 +94,11 @@ func (m *Manager) RegisterChannel(name string) (stats.Channel, error) { return nil, newError("Channel ", name, " already registered.") } newError("create new channel ", name).AtDebug().WriteToLog() - c := &Channel{channel: make(chan interface{})} + c := NewChannel(&ChannelConfig{BufferSize: 16, BroadcastTimeout: 100}) m.channels[name] = c - go c.Start() + if m.running { + return c, c.Start() + } return c, nil } @@ -181,9 +107,10 @@ func (m *Manager) UnregisterChannel(name string) error { m.access.Lock() defer m.access.Unlock() - if _, found := m.channels[name]; found { + if c, found := m.channels[name]; found { newError("remove channel ", name).AtDebug().WriteToLog() delete(m.channels, name) + return c.Close() } return nil } @@ -201,11 +128,42 @@ func (m *Manager) GetChannel(name string) stats.Channel { // Start implements common.Runnable. func (m *Manager) Start() error { + m.access.Lock() + defer m.access.Unlock() + m.running = true + errs := []error{} + for _, channel := range m.channels { + if err := channel.Start(); err != nil { + errs = append(errs, err) + } + } + if len(errs) != 0 { + return errors.Combine(errs...) + } return nil } // Close implement common.Closable. func (m *Manager) Close() error { + m.access.Lock() + defer m.access.Unlock() + m.running = false + errs := []error{} + for name, channel := range m.channels { + newError("remove channel ", name).AtDebug().WriteToLog() + delete(m.channels, name) + if err := channel.Close(); err != nil { + errs = append(errs, err) + } + } + if len(errs) != 0 { + return errors.Combine(errs...) + } return nil } +func init() { + common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { + return NewManager(ctx, config.(*Config)) + })) +} \ No newline at end of file diff --git a/app/stats/stats_test.go b/app/stats/stats_test.go index 0c724257..1641021d 100644 --- a/app/stats/stats_test.go +++ b/app/stats/stats_test.go @@ -2,7 +2,6 @@ package stats_test import ( "context" - "fmt" "testing" "time" @@ -15,337 +14,73 @@ func TestInterface(t *testing.T) { _ = (stats.Manager)(new(Manager)) } -func TestStatsCounter(t *testing.T) { +func TestStatsChannelRunnable(t *testing.T) { raw, err := common.CreateObject(context.Background(), &Config{}) common.Must(err) m := raw.(stats.Manager) - c, err := m.RegisterCounter("test.counter") + + ch1, err := m.RegisterChannel("test.channel.1") + c1 := ch1.(*Channel) common.Must(err) - if v := c.Add(1); v != 1 { - t.Fatal("unpexcted Add(1) return: ", v, ", wanted ", 1) + if c1.Running() { + t.Fatalf("unexpected running channel: test.channel.%d", 1) } - if v := c.Set(0); v != 1 { - t.Fatal("unexpected Set(0) return: ", v, ", wanted ", 1) + common.Must(m.Start()) + + if !c1.Running() { + t.Fatalf("unexpected non-running channel: test.channel.%d", 1) } - if v := c.Value(); v != 0 { - t.Fatal("unexpected Value() return: ", v, ", wanted ", 0) - } -} - -func TestStatsChannel(t *testing.T) { - raw, err := common.CreateObject(context.Background(), &Config{}) - common.Must(err) - - m := raw.(stats.Manager) - c, err := m.RegisterChannel("test.channel") - common.Must(err) - - source := c.Channel() - a := c.Subscribe() - b := c.Subscribe() - defer c.Unsubscribe(a) - defer c.Unsubscribe(b) - - stopCh := make(chan struct{}) - errCh := make(chan string) - - go func() { - source <- 1 - source <- 2 - source <- "3" - source <- []int{4} - source <- nil // Dummy messsage with no subscriber receiving - select { - case source <- nil: // Source should be blocked here, for last message was not cleared - errCh <- fmt.Sprint("unexpected non-blocked source") - default: - close(stopCh) - } - }() - - go func() { - if v, ok := (<-a).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) - } - if v, ok := (<-a).(int); !ok || v != 2 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) - } - if v, ok := (<-a).(string); !ok || v != "3" { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", "3") - } - if v, ok := (<-a).([]int); !ok || v[0] != 4 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) - } - }() - - go func() { - if v, ok := (<-b).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) - } - if v, ok := (<-b).(int); !ok || v != 2 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) - } - if v, ok := (<-b).(string); !ok || v != "3" { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", "3") - } - if v, ok := (<-b).([]int); !ok || v[0] != 4 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) - } - }() - - select { - case <-time.After(2 * time.Second): - t.Fatal("Test timeout after 2s") - case e := <-errCh: - t.Fatal(e) - case <-stopCh: - } -} - -func TestStatsChannelUnsubcribe(t *testing.T) { - raw, err := common.CreateObject(context.Background(), &Config{}) - common.Must(err) - - m := raw.(stats.Manager) - c, err := m.RegisterChannel("test.channel") - common.Must(err) - - source := c.Channel() - a := c.Subscribe() - b := c.Subscribe() - defer c.Unsubscribe(a) - - pauseCh := make(chan struct{}) - stopCh := make(chan struct{}) - errCh := make(chan string) - - { - var aSet, bSet bool - for _, s := range c.Subscribers() { - if s == a { - aSet = true - } - if s == b { - bSet = true - } - } - if !(aSet && bSet) { - t.Fatal("unexpected subscribers: ", c.Subscribers()) - } - } - - go func() { - source <- 1 - <-pauseCh // Wait for `b` goroutine to resume sending message - source <- 2 - }() - - go func() { - if v, ok := (<-a).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) - } - if v, ok := (<-a).(int); !ok || v != 2 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) - } - }() - - go func() { - if v, ok := (<-b).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) - } - // Unsubscribe `b` while `source`'s messaging is paused - c.Unsubscribe(b) - { // Test `b` is not in subscribers - var aSet, bSet bool - for _, s := range c.Subscribers() { - if s == a { - aSet = true - } - if s == b { - bSet = true - } - } - if !(aSet && !bSet) { - errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) - } - } - // Resume `source`'s progress - close(pauseCh) - // Test `b` is neither closed nor able to receive any data - select { - case v, ok := <-b: - if ok { - errCh <- fmt.Sprint("unexpected data received: ", v) - } else { - errCh <- fmt.Sprint("unexpected closed channel: ", b) - } - default: - } - close(stopCh) - }() - - select { - case <-time.After(2 * time.Second): - t.Fatal("Test timeout after 2s") - case e := <-errCh: - t.Fatal(e) - case <-stopCh: - } -} - -func TestStatsChannelTimeout(t *testing.T) { - raw, err := common.CreateObject(context.Background(), &Config{}) - common.Must(err) - - m := raw.(stats.Manager) - c, err := m.RegisterChannel("test.channel") - common.Must(err) - - source := c.Channel() - a := c.Subscribe() - b := c.Subscribe() - defer c.Unsubscribe(a) - defer c.Unsubscribe(b) - - stopCh := make(chan struct{}) - errCh := make(chan string) - - go func() { - source <- 1 - source <- 2 - }() - - go func() { - if v, ok := (<-a).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) - } - if v, ok := (<-a).(int); !ok || v != 2 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) - } - { // Test `b` is still in subscribers yet (because `a` receives 2 first) - var aSet, bSet bool - for _, s := range c.Subscribers() { - if s == a { - aSet = true - } - if s == b { - bSet = true - } - } - if !(aSet && bSet) { - errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) - } - } - }() - - go func() { - if v, ok := (<-b).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) - } - // Block `b` channel for a time longer than `source`'s timeout - <-time.After(150 * time.Millisecond) - { // Test `b` has been unsubscribed by source - var aSet, bSet bool - for _, s := range c.Subscribers() { - if s == a { - aSet = true - } - if s == b { - bSet = true - } - } - if !(aSet && !bSet) { - errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) - } - } - select { // Test `b` has been closed by source - case v, ok := <-b: - if ok { - errCh <- fmt.Sprint("unexpected data received: ", v) - } - default: - } - close(stopCh) - }() - - select { - case <-time.After(2 * time.Second): - t.Fatal("Test timeout after 2s") - case e := <-errCh: - t.Fatal(e) - case <-stopCh: - } -} - -func TestStatsChannelConcurrency(t *testing.T) { - raw, err := common.CreateObject(context.Background(), &Config{}) - common.Must(err) - - m := raw.(stats.Manager) - c, err := m.RegisterChannel("test.channel") - common.Must(err) - - source := c.Channel() - a := c.Subscribe() - b := c.Subscribe() - defer c.Unsubscribe(a) - - stopCh := make(chan struct{}) - errCh := make(chan string) - - go func() { - source <- 1 - source <- 2 - }() - - go func() { - if v, ok := (<-a).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) - } - if v, ok := (<-a).(int); !ok || v != 2 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) - } - }() - - go func() { - // Block `b` for a time shorter than `source`'s timeout - // So as to ensure source channel is trying to send message to `b`. - <-time.After(25 * time.Millisecond) - // This causes concurrency scenario: unsubscribe `b` while trying to send message to it - c.Unsubscribe(b) - // Test `b` is not closed and can still receive data 1: - // Because unsubscribe won't affect the ongoing process of sending message. - select { - case v, ok := <-b: - if v1, ok1 := v.(int); !(ok && ok1 && v1 == 1) { - errCh <- fmt.Sprint("unexpected failure in receiving data: ", 1) - } - default: - errCh <- fmt.Sprint("unexpected block from receiving data: ", 1) - } - // Test `b` is not closed but cannot receive data 2: - // Becuase in a new round of messaging, `b` has been unsubscribed. - select { - case v, ok := <-b: - if ok { - errCh <- fmt.Sprint("unexpected receving: ", v) - } else { - errCh <- fmt.Sprint("unexpected closing of channel") - } - default: - } - close(stopCh) - }() - - select { - case <-time.After(2 * time.Second): - t.Fatal("Test timeout after 2s") - case e := <-errCh: - t.Fatal(e) - case <-stopCh: + ch2, err := m.RegisterChannel("test.channel.2") + c2 := ch2.(*Channel) + common.Must(err) + + if !c2.Running() { + t.Fatalf("unexpected non-running channel: test.channel.%d", 2) + } + + s1, err := c1.Subscribe() + common.Must(err) + common.Must(c1.Close()) + + if c1.Running() { + t.Fatalf("unexpected running channel: test.channel.%d", 1) + } + + select { // Check all subscribers in closed channel are closed + case _, ok := <-s1: + if ok { + t.Fatalf("unexpected non-closed subscriber in channel: test.channel.%d", 1) + } + case <-time.After(500 * time.Millisecond): + t.Fatalf("unexpected non-closed subscriber in channel: test.channel.%d", 1) + } + + if len(c1.Subscribers()) != 0 { // Check subscribers in closed channel are emptied + t.Fatalf("unexpected non-empty subscribers in channel: test.channel.%d", 1) + } + + common.Must(m.Close()) + + if c2.Running() { + t.Fatalf("unexpected running channel: test.channel.%d", 2) + } + + ch3, err := m.RegisterChannel("test.channel.3") + c3 := ch3.(*Channel) + common.Must(err) + + if c3.Running() { + t.Fatalf("unexpected running channel: test.channel.%d", 3) + } + + common.Must(c3.Start()) + common.Must(m.UnregisterChannel("test.channel.3")) + + if c3.Running() { // Test that unregistering will close the channel. + t.Fatalf("unexpected running channel: test.channel.%d", 3) } } diff --git a/core.go b/core.go index df47249a..29fa268e 100644 --- a/core.go +++ b/core.go @@ -19,7 +19,7 @@ import ( ) var ( - version = "4.28.2" + version = "4.30.0" build = "Custom" codename = "V2Fly, a community-driven edition of V2Ray." intro = "A unified platform for anti-censorship." diff --git a/features/routing/dns/context.go b/features/routing/dns/context.go new file mode 100644 index 00000000..fca58701 --- /dev/null +++ b/features/routing/dns/context.go @@ -0,0 +1,44 @@ +package dns + +//go:generate errorgen + +import ( + "v2ray.com/core/common/net" + "v2ray.com/core/features/dns" + "v2ray.com/core/features/routing" +) + +// ResolvableContext is an implementation of routing.Context, with domain resolving capability. +type ResolvableContext struct { + routing.Context + dnsClient dns.Client + resolvedIPs []net.IP +} + +// GetTargetIPs overrides original routing.Context's implementation. +func (ctx *ResolvableContext) GetTargetIPs() []net.IP { + if ips := ctx.Context.GetTargetIPs(); len(ips) != 0 { + return ips + } + + if len(ctx.resolvedIPs) > 0 { + return ctx.resolvedIPs + } + + if domain := ctx.GetTargetDomain(); len(domain) != 0 { + ips, err := ctx.dnsClient.LookupIP(domain) + if err == nil { + ctx.resolvedIPs = ips + return ips + } + newError("resolve ip for ", domain).Base(err).WriteToLog() + } + + return nil +} + +// ContextWithDNSClient creates a new routing context with domain resolving capability. +// Resolved domain IPs can be retrieved by GetTargetIPs(). +func ContextWithDNSClient(ctx routing.Context, client dns.Client) routing.Context { + return &ResolvableContext{Context: ctx, dnsClient: client} +} diff --git a/features/routing/dns/errors.generated.go b/features/routing/dns/errors.generated.go new file mode 100644 index 00000000..ba70372f --- /dev/null +++ b/features/routing/dns/errors.generated.go @@ -0,0 +1,9 @@ +package dns + +import "v2ray.com/core/common/errors" + +type errPathObjHolder struct{} + +func newError(values ...interface{}) *errors.Error { + return errors.New(values...).WithPathObj(errPathObjHolder{}) +} diff --git a/features/routing/router.go b/features/routing/router.go index f473431a..2acc9651 100644 --- a/features/routing/router.go +++ b/features/routing/router.go @@ -7,12 +7,26 @@ import ( // Router is a feature to choose an outbound tag for the given request. // -// v2ray:api:beta +// v2ray:api:stable type Router interface { features.Feature - // PickRoute returns a tag of an OutboundHandler based on the given context. - PickRoute(ctx Context) (string, error) + // PickRoute returns a route decision based on the given routing context. + PickRoute(ctx Context) (Route, error) +} + +// Route is the routing result of Router feature. +// +// v2ray:api:stable +type Route interface { + // A Route is also a routing context. + Context + + // GetOutboundGroupTags returns the detoured outbound group tags in sequence before a final outbound is chosen. + GetOutboundGroupTags() []string + + // GetOutboundTag returns the tag of the outbound the connection was dispatched to. + GetOutboundTag() string } // RouterType return the type of Router interface. Can be used to implement common.HasType. @@ -31,8 +45,8 @@ func (DefaultRouter) Type() interface{} { } // PickRoute implements Router. -func (DefaultRouter) PickRoute(ctx Context) (string, error) { - return "", common.ErrNoClue +func (DefaultRouter) PickRoute(ctx Context) (Route, error) { + return nil, common.ErrNoClue } // Start implements common.Runnable. diff --git a/features/stats/stats.go b/features/stats/stats.go index a27b441c..73fae0f4 100644 --- a/features/stats/stats.go +++ b/features/stats/stats.go @@ -2,7 +2,10 @@ package stats //go:generate errorgen -import "v2ray.com/core/features" +import ( + "v2ray.com/core/common" + "v2ray.com/core/features" +) // Counter is the interface for stats counters. // @@ -16,18 +19,41 @@ type Counter interface { Add(int64) int64 } -// Channel is the interface for stats channel +// Channel is the interface for stats channel. // // v2ray:api:stable type Channel interface { - // Channel returns the underlying go channel. - Channel() chan interface{} + // Channel is a runnable unit. + common.Runnable + // Publish broadcasts a message through the channel. + Publish(interface{}) // SubscriberCount returns the number of the subscribers. Subscribers() []chan interface{} // Subscribe registers for listening to channel stream and returns a new listener channel. - Subscribe() chan interface{} + Subscribe() (chan interface{}, error) // Unsubscribe unregisters a listener channel from current Channel object. - Unsubscribe(chan interface{}) + Unsubscribe(chan interface{}) error +} + +// SubscribeRunnableChannel subscribes the channel and starts it if there is first subscriber coming. +func SubscribeRunnableChannel(c Channel) (chan interface{}, error) { + if len(c.Subscribers()) == 0 { + if err := c.Start(); err != nil { + return nil, err + } + } + return c.Subscribe() +} + +// UnsubscribeClosableChannel unsubcribes the channel and close it if there is no more subscriber. +func UnsubscribeClosableChannel(c Channel, sub chan interface{}) error { + if err := c.Unsubscribe(sub); err != nil { + return err + } + if len(c.Subscribers()) == 0 { + return c.Close() + } + return nil } // Manager is the interface for stats manager. diff --git a/go.mod b/go.mod index be3148b1..f3b5e9d2 100644 --- a/go.mod +++ b/go.mod @@ -13,11 +13,12 @@ require ( github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841 github.com/stretchr/testify v1.6.1 github.com/xiaokangwang/VSign v0.0.0-20200828155424-dc1c86b73fbf + github.com/xtls/go v0.0.0-20200921133830-416584838a0f go.starlark.net v0.0.0-20200901195727-6e684ef5eeee golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a - golang.org/x/net v0.0.0-20200822124328-c89045814202 + golang.org/x/net v0.0.0-20200904194848-62affa334b73 golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 - golang.org/x/sys v0.0.0-20200831180312-196b9ba8737a + golang.org/x/sys v0.0.0-20200918174421-af09f7315aff google.golang.org/grpc v1.32.0 google.golang.org/protobuf v1.25.0 h12.io/socks v1.0.1 diff --git a/go.sum b/go.sum index 32160f70..a1ff9f8f 100644 --- a/go.sum +++ b/go.sum @@ -6,7 +6,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 h1:BS21ZUJ/B5X2UVUbczfmdWH7GapPWAhxcMsDnjJTU1E= github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= @@ -18,7 +17,6 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -36,28 +34,24 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/h12w/go-socks5 v0.0.0-20200522160539-76189e178364 h1:5XxdakFhqd9dnXoAZy1Mb2R/DZ6D1e+0bGC/JhucGYI= github.com/h12w/go-socks5 v0.0.0-20200522160539-76189e178364/go.mod h1:eDJQioIyy4Yn3MVivT7rv/39gAJTrA7lgmYr8EW950c= -github.com/miekg/dns v1.1.31 h1:sJFOl9BgwbYAWOGEwr61FU28pqsBNdpRBnhGXtO06Oo= github.com/miekg/dns v1.1.31/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= -github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 h1:JhzVVoYvbOACxoUmOs6V/G4D5nPVUW73rKvXxP4XUJc= github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/pires/go-proxyproto v0.1.3 h1:2XEuhsQluSNA5QIQkiUv8PfgZ51sNYIQkq/yFquiSQM= github.com/pires/go-proxyproto v0.1.3/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841 h1:pnfutQFsV7ySmHUeX6ANGfPsBo29RctUvDn8G3rmJVw= github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841/go.mod h1:ET5mVvNjwaGXRgZxO9UZr7X+8eAf87AfIYNwRSp9s4Y= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xiaokangwang/VSign v0.0.0-20200828155424-dc1c86b73fbf h1:d4keT3SwLbrgnEe2zbtijPLgKE15n0ZbvJZzRH/a9GM= github.com/xiaokangwang/VSign v0.0.0-20200828155424-dc1c86b73fbf/go.mod h1:jTwBnzBuqZP3VX/Z65ErYb9zd4anQprSC7N38TmAp1E= +github.com/xtls/go v0.0.0-20200921133830-416584838a0f h1:HNJx0SKT77PmtX0Xj8Ep5ak3cIG19ZFxCYkMa2yJfSg= +github.com/xtls/go v0.0.0-20200921133830-416584838a0f/go.mod h1:5TB2+k58gx4A4g2Nf5miSHNDF6CuAzHKpWBooLAshTs= go.starlark.net v0.0.0-20200901195727-6e684ef5eeee h1:N4eRtIIYHZE5Mw/Km/orb+naLdwAe+lv2HCxRR5rEBw= go.starlark.net v0.0.0-20200901195727-6e684ef5eeee/go.mod h1:f0znQkUKRrkk36XxWbGjMqQM8wGv/xHBVE2qc3B5oFU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -77,13 +71,12 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA= +golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -91,8 +84,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200831180312-196b9ba8737a h1:i47hUS795cOydZI4AwJQCKXOr4BvxzvikwDoDtHhP2Y= -golang.org/x/sys v0.0.0-20200831180312-196b9ba8737a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200918174421-af09f7315aff h1:1CPUrky56AcgSpxz/KfgzQWzfG09u5YOL8MvPYBlrL8= +golang.org/x/sys v0.0.0-20200918174421-af09f7315aff/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -102,7 +95,6 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -114,8 +106,6 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.31.1 h1:SfXqXS5hkufcdZ/mHtYCh53P2b+92WQq/DZcKLgsFRs= -google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0= google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -128,11 +118,8 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -h12.io/socks v1.0.1 h1:bXESSI/+hbdrp+22vcc7/JiXjmLH4UWktKdYgGr3ShA= h12.io/socks v1.0.1/go.mod h1:AIhxy1jOId/XCz9BO+EIgNL2rQiPTBNnOfnVnQ+3Eck= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/infra/conf/router.go b/infra/conf/router.go index e2507785..10b51993 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -162,7 +162,7 @@ func loadIP(filename, country string) ([]*router.CIDR, error) { } } - return nil, newError("country not found: " + country) + return nil, newError("country not found in ", filename, ": ", country) } func loadSite(filename, country string) ([]*router.Domain, error) { @@ -181,7 +181,7 @@ func loadSite(filename, country string) ([]*router.Domain, error) { } } - return nil, newError("country not found: " + country) + return nil, newError("list not found in ", filename, ": ", country) } type AttributeMatcher interface { diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 14655fe5..65f4151b 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -16,6 +16,7 @@ import ( "v2ray.com/core/transport/internet/tcp" "v2ray.com/core/transport/internet/tls" "v2ray.com/core/transport/internet/websocket" + "v2ray.com/core/transport/internet/xtls" ) var ( @@ -168,6 +169,7 @@ type HTTPConfig struct { Path string `json:"path"` } +// Build implements Buildable. func (c *HTTPConfig) Build() (proto.Message, error) { config := &http.Config{ Path: c.Path, @@ -184,6 +186,7 @@ type QUICConfig struct { Key string `json:"key"` } +// Build implements Buildable. func (c *QUICConfig) Build() (proto.Message, error) { config := &quic.Config{ Key: c.Key, @@ -225,6 +228,7 @@ type DomainSocketConfig struct { AcceptProxyProtocol bool `json:"acceptProxyProtocol"` } +// Build implements Buildable. func (c *DomainSocketConfig) Build() (proto.Message, error) { return &domainsocket.Config{ Path: c.Path, @@ -234,14 +238,6 @@ func (c *DomainSocketConfig) Build() (proto.Message, error) { }, nil } -type TLSCertConfig struct { - CertFile string `json:"certificateFile"` - CertStr []string `json:"certificate"` - KeyFile string `json:"keyFile"` - KeyStr []string `json:"key"` - Usage string `json:"usage"` -} - func readFileOrString(f string, s []string) ([]byte, error) { if len(f) > 0 { return filesystem.ReadFile(f) @@ -252,6 +248,15 @@ func readFileOrString(f string, s []string) ([]byte, error) { return nil, newError("both file and bytes are empty.") } +type TLSCertConfig struct { + CertFile string `json:"certificateFile"` + CertStr []string `json:"certificate"` + KeyFile string `json:"keyFile"` + KeyStr []string `json:"key"` + Usage string `json:"usage"` +} + +// Build implements Buildable. func (c *TLSCertConfig) Build() (*tls.Certificate, error) { certificate := new(tls.Certificate) @@ -318,6 +323,81 @@ func (c *TLSConfig) Build() (proto.Message, error) { return config, nil } +type XTLSCertConfig struct { + CertFile string `json:"certificateFile"` + CertStr []string `json:"certificate"` + KeyFile string `json:"keyFile"` + KeyStr []string `json:"key"` + Usage string `json:"usage"` +} + +// Build implements Buildable. +func (c *XTLSCertConfig) Build() (*xtls.Certificate, error) { + certificate := new(xtls.Certificate) + + cert, err := readFileOrString(c.CertFile, c.CertStr) + if err != nil { + return nil, newError("failed to parse certificate").Base(err) + } + certificate.Certificate = cert + + if len(c.KeyFile) > 0 || len(c.KeyStr) > 0 { + key, err := readFileOrString(c.KeyFile, c.KeyStr) + if err != nil { + return nil, newError("failed to parse key").Base(err) + } + certificate.Key = key + } + + switch strings.ToLower(c.Usage) { + case "encipherment": + certificate.Usage = xtls.Certificate_ENCIPHERMENT + case "verify": + certificate.Usage = xtls.Certificate_AUTHORITY_VERIFY + case "issue": + certificate.Usage = xtls.Certificate_AUTHORITY_ISSUE + default: + certificate.Usage = xtls.Certificate_ENCIPHERMENT + } + + return certificate, nil +} + +type XTLSConfig struct { + Insecure bool `json:"allowInsecure"` + InsecureCiphers bool `json:"allowInsecureCiphers"` + Certs []*XTLSCertConfig `json:"certificates"` + ServerName string `json:"serverName"` + ALPN *StringList `json:"alpn"` + DisableSessionResumption bool `json:"disableSessionResumption"` + DisableSystemRoot bool `json:"disableSystemRoot"` +} + +// Build implements Buildable. +func (c *XTLSConfig) Build() (proto.Message, error) { + config := new(xtls.Config) + config.Certificate = make([]*xtls.Certificate, len(c.Certs)) + for idx, certConf := range c.Certs { + cert, err := certConf.Build() + if err != nil { + return nil, err + } + config.Certificate[idx] = cert + } + serverName := c.ServerName + config.AllowInsecure = c.Insecure + config.AllowInsecureCiphers = c.InsecureCiphers + if len(c.ServerName) > 0 { + config.ServerName = serverName + } + if c.ALPN != nil && len(*c.ALPN) > 0 { + config.NextProtocol = []string(*c.ALPN) + } + config.DisableSessionResumption = c.DisableSessionResumption + config.DisableSystemRoot = c.DisableSystemRoot + return config, nil +} + type TransportProtocol string // Build implements Buildable. @@ -346,6 +426,7 @@ type SocketConfig struct { TProxy string `json:"tproxy"` } +// Build implements Buildable. func (c *SocketConfig) Build() (*internet.SocketConfig, error) { var tfoSettings internet.SocketConfig_TCPFastOpenState if c.TFO != nil { @@ -376,6 +457,7 @@ type StreamConfig struct { Network *TransportProtocol `json:"network"` Security string `json:"security"` TLSSettings *TLSConfig `json:"tlsSettings"` + XTLSSettings *XTLSConfig `json:"xtlsSettings"` TCPSettings *TCPConfig `json:"tcpSettings"` KCPSettings *KCPConfig `json:"kcpSettings"` WSSettings *WebSocketConfig `json:"wsSettings"` @@ -400,6 +482,9 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { if strings.EqualFold(c.Security, "tls") { tlsSettings := c.TLSSettings if tlsSettings == nil { + if c.XTLSSettings != nil { + return nil, newError(`TLS: Please use "tlsSettings" instead of "xtlsSettings".`) + } tlsSettings = &TLSConfig{} } ts, err := tlsSettings.Build() @@ -410,6 +495,25 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { config.SecuritySettings = append(config.SecuritySettings, tm) config.SecurityType = tm.Type } + if strings.EqualFold(c.Security, "xtls") { + if config.ProtocolName != "tcp" && config.ProtocolName != "domainsocket" { + return nil, newError("XTLS only supports TCP and DomainSocket for now.") + } + xtlsSettings := c.XTLSSettings + if xtlsSettings == nil { + if c.TLSSettings != nil { + return nil, newError(`XTLS: Please use "xtlsSettings" instead of "tlsSettings".`) + } + xtlsSettings = &XTLSConfig{} + } + ts, err := xtlsSettings.Build() + if err != nil { + return nil, newError("Failed to build XTLS config.").Base(err) + } + tm := serial.ToTypedMessage(ts) + config.SecuritySettings = append(config.SecuritySettings, tm) + config.SecurityType = tm.Type + } if c.TCPSettings != nil { ts, err := c.TCPSettings.Build() if err != nil { @@ -463,7 +567,7 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { if c.QUICSettings != nil { qs, err := c.QUICSettings.Build() if err != nil { - return nil, newError("failed to build QUIC config").Base(err) + return nil, newError("Failed to build QUIC config").Base(err) } config.TransportSettings = append(config.TransportSettings, &internet.TransportConfig{ ProtocolName: "quic", @@ -473,7 +577,7 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { if c.SocketSettings != nil { ss, err := c.SocketSettings.Build() if err != nil { - return nil, newError("failed to build sockopt").Base(err) + return nil, newError("Failed to build sockopt").Base(err) } config.SocketSettings = ss } diff --git a/infra/conf/trojan.go b/infra/conf/trojan.go new file mode 100644 index 00000000..4d0c15d8 --- /dev/null +++ b/infra/conf/trojan.go @@ -0,0 +1,135 @@ +package conf + +import ( + "strconv" + + "github.com/golang/protobuf/proto" // nolint: staticcheck + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" + "v2ray.com/core/common/serial" + "v2ray.com/core/proxy/trojan" +) + +// TrojanServerTarget is configuration of a single trojan server +type TrojanServerTarget struct { + Address *Address `json:"address"` + Port uint16 `json:"port"` + Password string `json:"password"` + Email string `json:"email"` + Level byte `json:"level"` +} + +// TrojanClientConfig is configuration of trojan servers +type TrojanClientConfig struct { + Servers []*TrojanServerTarget `json:"servers"` +} + +// Build implements Buildable +func (c *TrojanClientConfig) Build() (proto.Message, error) { + config := new(trojan.ClientConfig) + + if len(c.Servers) == 0 { + return nil, newError("0 Trojan server configured.") + } + + serverSpecs := make([]*protocol.ServerEndpoint, len(c.Servers)) + for idx, rec := range c.Servers { + if rec.Address == nil { + return nil, newError("Trojan server address is not set.") + } + if rec.Port == 0 { + return nil, newError("Invalid Trojan port.") + } + if rec.Password == "" { + return nil, newError("Trojan password is not specified.") + } + account := &trojan.Account{ + Password: rec.Password, + } + trojan := &protocol.ServerEndpoint{ + Address: rec.Address.Build(), + Port: uint32(rec.Port), + User: []*protocol.User{ + { + Level: uint32(rec.Level), + Email: rec.Email, + Account: serial.ToTypedMessage(account), + }, + }, + } + + serverSpecs[idx] = trojan + } + + config.Server = serverSpecs + + return config, nil +} + +// TrojanInboundFallback is fallback configuration +type TrojanInboundFallback struct { + Type string `json:"type"` + Dest string `json:"dest"` +} + +// TrojanUserConfig is user configuration +type TrojanUserConfig struct { + Password string `json:"password"` + Level byte `json:"level"` + Email string `json:"email"` +} + +// TrojanServerConfig is Inbound configuration +type TrojanServerConfig struct { + Clients []*TrojanUserConfig `json:"clients"` + Fallback *TrojanInboundFallback `json:"fallback"` +} + +// Build implements Buildable +func (c *TrojanServerConfig) Build() (proto.Message, error) { + config := new(trojan.ServerConfig) + + if len(c.Clients) == 0 { + return nil, newError("No trojan user settings.") + } + + config.Users = make([]*protocol.User, len(c.Clients)) + for idx, rawUser := range c.Clients { + user := new(protocol.User) + account := &trojan.Account{ + Password: rawUser.Password, + } + + user.Email = rawUser.Email + user.Level = uint32(rawUser.Level) + user.Account = serial.ToTypedMessage(account) + config.Users[idx] = user + } + + if c.Fallback != nil { + fb := &trojan.Fallback{ + Dest: c.Fallback.Dest, + } + + if fb.Type == "" && fb.Dest != "" { + switch fb.Dest[0] { + case '@', '/': + fb.Type = "unix" + default: + if _, err := strconv.Atoi(fb.Dest); err == nil { + fb.Dest = "127.0.0.1:" + fb.Dest + } + if _, _, err := net.SplitHostPort(fb.Dest); err == nil { + fb.Type = "tcp" + } + } + } + if fb.Type == "" { + return nil, newError("please fill in a valid value for trojan fallback type") + } + + config.Fallback = fb + } + + return config, nil +} diff --git a/infra/conf/v2ray.go b/infra/conf/v2ray.go index f8499923..8625f1f5 100644 --- a/infra/conf/v2ray.go +++ b/infra/conf/v2ray.go @@ -11,6 +11,7 @@ import ( "v2ray.com/core/app/proxyman" "v2ray.com/core/app/stats" "v2ray.com/core/common/serial" + "v2ray.com/core/transport/internet/xtls" ) var ( @@ -21,6 +22,7 @@ var ( "socks": func() interface{} { return new(SocksServerConfig) }, "vless": func() interface{} { return new(VLessInboundConfig) }, "vmess": func() interface{} { return new(VMessInboundConfig) }, + "trojan": func() interface{} { return new(TrojanServerConfig) }, "mtproto": func() interface{} { return new(MTProtoServerConfig) }, }, "protocol", "settings") @@ -32,6 +34,7 @@ var ( "socks": func() interface{} { return new(SocksClientConfig) }, "vless": func() interface{} { return new(VLessOutboundConfig) }, "vmess": func() interface{} { return new(VMessOutboundConfig) }, + "trojan": func() interface{} { return new(TrojanClientConfig) }, "mtproto": func() interface{} { return new(MTProtoClientConfig) }, "dns": func() interface{} { return new(DnsOutboundConfig) }, }, "protocol", "settings") @@ -59,6 +62,7 @@ type SniffingConfig struct { DestOverride *StringList `json:"destOverride"` } +// Build implements Buildable. func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) { var p []string if c.DestOverride != nil { @@ -184,6 +188,9 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) { if err != nil { return nil, err } + if ss.SecurityType == serial.GetMessageType(&xtls.Config{}) && !strings.EqualFold(c.Protocol, "vless") { + return nil, newError("XTLS only supports VLESS for now.") + } receiverSettings.StreamSettings = ss } if c.SniffingConfig != nil { @@ -251,6 +258,9 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { if err != nil { return nil, err } + if ss.SecurityType == serial.GetMessageType(&xtls.Config{}) && !strings.EqualFold(c.Protocol, "vless") { + return nil, newError("XTLS only supports VLESS for now.") + } senderSettings.StreamSettings = ss } @@ -263,7 +273,15 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { } if c.MuxSettings != nil { - senderSettings.MultiplexSettings = c.MuxSettings.Build() + ms := c.MuxSettings.Build() + if ms != nil && ms.Enabled { + if ss := senderSettings.StreamSettings; ss != nil { + if ss.SecurityType == serial.GetMessageType(&xtls.Config{}) { + return nil, newError("XTLS doesn't support Mux for now.") + } + } + } + senderSettings.MultiplexSettings = ms } settings := []byte("{}") @@ -288,6 +306,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { type StatsConfig struct{} +// Build implements Buildable. func (c *StatsConfig) Build() (*stats.Config, error) { return &stats.Config{}, nil } diff --git a/infra/conf/v2ray_test.go b/infra/conf/v2ray_test.go index 8c59a7ae..e51a42c5 100644 --- a/infra/conf/v2ray_test.go +++ b/infra/conf/v2ray_test.go @@ -404,39 +404,39 @@ func TestConfig_Override(t *testing.T) { }, }, {"combine/newattr", - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "old"}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "old"}}}, &Config{LogConfig: &LogConfig{}}, "", - &Config{LogConfig: &LogConfig{}, InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "old"}}}}, + &Config{LogConfig: &LogConfig{}, InboundConfigs: []InboundDetourConfig{{Tag: "old"}}}}, {"replace/inbounds", - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos0"}, InboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "pos0"}, {Protocol: "vmess", Tag: "pos1"}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "pos1", Protocol: "kcp"}}}, "", - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos0"}, InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "pos0"}, {Tag: "pos1", Protocol: "kcp"}}}}, {"replace/inbounds-replaceall", - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos0"}, InboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, InboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "pos0"}, {Protocol: "vmess", Tag: "pos1"}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "pos1", Protocol: "kcp"}, {Tag: "pos2", Protocol: "kcp"}}}, "", - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, InboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "pos1", Protocol: "kcp"}, {Tag: "pos2", Protocol: "kcp"}}}}, {"replace/notag-append", - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{}, InboundDetourConfig{Protocol: "vmess"}}}, - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}, + &Config{InboundConfigs: []InboundDetourConfig{{}, {Protocol: "vmess"}}}, + &Config{InboundConfigs: []InboundDetourConfig{{Tag: "pos1", Protocol: "kcp"}}}, "", - &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{}, InboundDetourConfig{Protocol: "vmess"}, InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}}, + &Config{InboundConfigs: []InboundDetourConfig{{}, {Protocol: "vmess"}, {Tag: "pos1", Protocol: "kcp"}}}}, {"replace/outbounds", - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos0"}, {Protocol: "vmess", Tag: "pos1"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos1", Protocol: "kcp"}}}, "", - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos0"}, {Tag: "pos1", Protocol: "kcp"}}}}, {"replace/outbounds-prepend", - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, OutboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos0"}, {Protocol: "vmess", Tag: "pos1"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos1", Protocol: "kcp"}, {Tag: "pos2", Protocol: "kcp"}}}, "config.json", - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, OutboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos1", Protocol: "kcp"}, {Tag: "pos2", Protocol: "kcp"}}}}, {"replace/outbounds-append", - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos0"}, {Protocol: "vmess", Tag: "pos1"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos2", Protocol: "kcp"}}}, "config_tail.json", - &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Protocol: "vmess", Tag: "pos1"}, OutboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{{Tag: "pos0"}, {Protocol: "vmess", Tag: "pos1"}, {Tag: "pos2", Protocol: "kcp"}}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/infra/conf/vless.go b/infra/conf/vless.go index 7d222bba..00f2f394 100644 --- a/infra/conf/vless.go +++ b/infra/conf/vless.go @@ -48,9 +48,12 @@ func (c *VLessInboundConfig) Build() (proto.Message, error) { return nil, newError(`VLESS clients: invalid user`).Base(err) } - if account.Flow != "" { - return nil, newError(`VLESS clients: "flow" is not available in this version`) + switch account.Flow { + case "", "xtls-rprx-origin": + default: + return nil, newError(`VLESS clients: "flow" only accepts "", "xtls-rprx-origin" in this version`) } + if account.Encryption != "" { return nil, newError(`VLESS clients: "encryption" should not in inbound settings`) } @@ -161,9 +164,12 @@ func (c *VLessOutboundConfig) Build() (proto.Message, error) { return nil, newError(`VLESS users: invalid user`).Base(err) } - if account.Flow != "" { - return nil, newError(`VLESS users: "flow" is not available in this version`) + switch account.Flow { + case "", "xtls-rprx-origin", "xtls-rprx-origin-udp443": + default: + return nil, newError(`VLESS users: "flow" only accepts "", "xtls-rprx-origin", "xtls-rprx-origin-udp443" in this version`) } + if account.Encryption != "none" { return nil, newError(`VLESS users: please add/set "encryption":"none" for every user`) } diff --git a/infra/conf/vless_test.go b/infra/conf/vless_test.go index 12035095..01eb9619 100644 --- a/infra/conf/vless_test.go +++ b/infra/conf/vless_test.go @@ -26,6 +26,7 @@ func TestVLessOutbound(t *testing.T) { "users": [ { "id": "27848739-7e62-4138-9fd3-098a63964b6b", + "flow": "xtls-rprx-origin-udp443", "encryption": "none", "level": 0 } @@ -46,6 +47,7 @@ func TestVLessOutbound(t *testing.T) { { Account: serial.ToTypedMessage(&vless.Account{ Id: "27848739-7e62-4138-9fd3-098a63964b6b", + Flow: "xtls-rprx-origin-udp443", Encryption: "none", }), Level: 0, @@ -69,6 +71,7 @@ func TestVLessInbound(t *testing.T) { "clients": [ { "id": "27848739-7e62-4138-9fd3-098a63964b6b", + "flow": "xtls-rprx-origin", "level": 0, "email": "love@v2fly.org" } @@ -94,7 +97,8 @@ func TestVLessInbound(t *testing.T) { Clients: []*protocol.User{ { Account: serial.ToTypedMessage(&vless.Account{ - Id: "27848739-7e62-4138-9fd3-098a63964b6b", + Id: "27848739-7e62-4138-9fd3-098a63964b6b", + Flow: "xtls-rprx-origin", }), Level: 0, Email: "love@v2fly.org", diff --git a/main/distro/all/all.go b/main/distro/all/all.go index 9df5b6ac..8c3914c7 100644 --- a/main/distro/all/all.go +++ b/main/distro/all/all.go @@ -31,6 +31,7 @@ import ( _ "v2ray.com/core/proxy/mtproto" _ "v2ray.com/core/proxy/shadowsocks" _ "v2ray.com/core/proxy/socks" + _ "v2ray.com/core/proxy/trojan" _ "v2ray.com/core/proxy/vless/inbound" _ "v2ray.com/core/proxy/vless/outbound" _ "v2ray.com/core/proxy/vmess/inbound" @@ -45,6 +46,7 @@ import ( _ "v2ray.com/core/transport/internet/tls" _ "v2ray.com/core/transport/internet/udp" _ "v2ray.com/core/transport/internet/websocket" + _ "v2ray.com/core/transport/internet/xtls" // Transport headers _ "v2ray.com/core/transport/internet/headers/http" diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go new file mode 100644 index 00000000..bd2758b1 --- /dev/null +++ b/proxy/trojan/client.go @@ -0,0 +1,146 @@ +// +build !confonly + +package trojan + +import ( + "context" + "time" + + "v2ray.com/core" + "v2ray.com/core/common" + "v2ray.com/core/common/buf" + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" + "v2ray.com/core/common/retry" + "v2ray.com/core/common/session" + "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" + "v2ray.com/core/features/policy" + "v2ray.com/core/transport" + "v2ray.com/core/transport/internet" +) + +// Client is a inbound handler for trojan protocol +type Client struct { + serverPicker protocol.ServerPicker + policyManager policy.Manager +} + +// NewClient create a new trojan client. +func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { + serverList := protocol.NewServerList() + for _, rec := range config.Server { + s, err := protocol.NewServerSpecFromPB(rec) + if err != nil { + return nil, newError("failed to parse server spec").Base(err) + } + serverList.AddServer(s) + } + if serverList.Size() == 0 { + return nil, newError("0 server") + } + + v := core.MustFromContext(ctx) + client := &Client{ + serverPicker: protocol.NewRoundRobinServerPicker(serverList), + policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + } + return client, nil +} + +// Process implements OutboundHandler.Process(). +func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { // nolint: funlen + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { + return newError("target not specified") + } + destination := outbound.Target + network := destination.Network + + var server *protocol.ServerSpec + var conn internet.Connection + + err := retry.ExponentialBackoff(5, 100).On(func() error { // nolint: gomnd + server = c.serverPicker.PickServer() + rawConn, err := dialer.Dial(ctx, server.Destination()) + if err != nil { + return err + } + + conn = rawConn + return nil + }) + if err != nil { + return newError("failed to find an available destination").AtWarning().Base(err) + } + newError("tunneling request to ", destination, " via ", server.Destination()).WriteToLog(session.ExportIDToError(ctx)) + + defer conn.Close() + + user := server.PickUser() + account, ok := user.Account.(*MemoryAccount) + if !ok { + return newError("user account is not valid") + } + + sessionPolicy := c.policyManager.ForLevel(user.Level) + ctx, cancel := context.WithCancel(ctx) + timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + + postRequest := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + + var bodyWriter buf.Writer + bufferWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) + connWriter := &ConnWriter{Writer: bufferWriter, Target: destination, Account: account} + + if destination.Network == net.Network_UDP { + bodyWriter = &PacketWriter{Writer: connWriter, Target: destination} + } else { + bodyWriter = connWriter + } + + // write some request payload to buffer + if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout { // nolint: lll,gomnd + return newError("failed to write A reqeust payload").Base(err).AtWarning() + } + + // Flush; bufferWriter.WriteMultiBufer now is bufferWriter.writer.WriteMultiBuffer + if err = bufferWriter.SetBuffered(false); err != nil { + return newError("failed to flush payload").Base(err).AtWarning() + } + + if err = buf.Copy(link.Reader, bodyWriter, buf.UpdateActivity(timer)); err != nil { + return newError("failed to transfer request payload").Base(err).AtInfo() + } + + return nil + } + + getResponse := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + + var reader buf.Reader + if network == net.Network_UDP { + reader = &PacketReader{ + Reader: conn, + } + } else { + reader = buf.NewReader(conn) + } + return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)) + } + + var responseDoneAndCloseWriter = task.OnSuccess(getResponse, task.Close(link.Writer)) + if err := task.Run(ctx, postRequest, responseDoneAndCloseWriter); err != nil { + return newError("connection ends").Base(err) + } + + return nil +} + +func init() { + common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { // nolint: lll + return NewClient(ctx, config.(*ClientConfig)) + })) +} diff --git a/proxy/trojan/config.go b/proxy/trojan/config.go new file mode 100644 index 00000000..817358d4 --- /dev/null +++ b/proxy/trojan/config.go @@ -0,0 +1,50 @@ +package trojan + +import ( + "crypto/sha256" + "encoding/hex" + fmt "fmt" + + "v2ray.com/core/common" + "v2ray.com/core/common/protocol" +) + +// MemoryAccount is an account type converted from Account. +type MemoryAccount struct { + Password string + Key []byte +} + +// AsAccount implements protocol.AsAccount. +func (a *Account) AsAccount() (protocol.Account, error) { + password := a.GetPassword() + key := hexSha224(password) + return &MemoryAccount{ + Password: password, + Key: key, + }, nil +} + +// Equals implements protocol.Account.Equals(). +func (a *MemoryAccount) Equals(another protocol.Account) bool { + if account, ok := another.(*MemoryAccount); ok { + return a.Password == account.Password + } + return false +} + +func hexSha224(password string) []byte { + buf := make([]byte, 56) + hash := sha256.New224() + common.Must2(hash.Write([]byte(password))) + hex.Encode(buf, hash.Sum(nil)) + return buf +} + +func hexString(data []byte) string { + str := "" + for _, v := range data { + str += fmt.Sprintf("%02x", v) + } + return str +} diff --git a/proxy/trojan/config.pb.go b/proxy/trojan/config.pb.go new file mode 100644 index 00000000..2d012ff9 --- /dev/null +++ b/proxy/trojan/config.pb.go @@ -0,0 +1,376 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: proxy/trojan/config.proto + +package trojan + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + protocol "v2ray.com/core/common/protocol" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type Account struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` +} + +func (x *Account) Reset() { + *x = Account{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_trojan_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Account) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Account) ProtoMessage() {} + +func (x *Account) ProtoReflect() protoreflect.Message { + mi := &file_proxy_trojan_config_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Account.ProtoReflect.Descriptor instead. +func (*Account) Descriptor() ([]byte, []int) { + return file_proxy_trojan_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Account) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +type Fallback struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` + Dest string `protobuf:"bytes,2,opt,name=dest,proto3" json:"dest,omitempty"` +} + +func (x *Fallback) Reset() { + *x = Fallback{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_trojan_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Fallback) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Fallback) ProtoMessage() {} + +func (x *Fallback) ProtoReflect() protoreflect.Message { + mi := &file_proxy_trojan_config_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Fallback.ProtoReflect.Descriptor instead. +func (*Fallback) Descriptor() ([]byte, []int) { + return file_proxy_trojan_config_proto_rawDescGZIP(), []int{1} +} + +func (x *Fallback) GetType() string { + if x != nil { + return x.Type + } + return "" +} + +func (x *Fallback) GetDest() string { + if x != nil { + return x.Dest + } + return "" +} + +type ClientConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Server []*protocol.ServerEndpoint `protobuf:"bytes,1,rep,name=server,proto3" json:"server,omitempty"` +} + +func (x *ClientConfig) Reset() { + *x = ClientConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_trojan_config_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ClientConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientConfig) ProtoMessage() {} + +func (x *ClientConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_trojan_config_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientConfig.ProtoReflect.Descriptor instead. +func (*ClientConfig) Descriptor() ([]byte, []int) { + return file_proxy_trojan_config_proto_rawDescGZIP(), []int{2} +} + +func (x *ClientConfig) GetServer() []*protocol.ServerEndpoint { + if x != nil { + return x.Server + } + return nil +} + +type ServerConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Users []*protocol.User `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"` + Fallback *Fallback `protobuf:"bytes,2,opt,name=fallback,proto3" json:"fallback,omitempty"` +} + +func (x *ServerConfig) Reset() { + *x = ServerConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_trojan_config_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ServerConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ServerConfig) ProtoMessage() {} + +func (x *ServerConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_trojan_config_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ServerConfig.ProtoReflect.Descriptor instead. +func (*ServerConfig) Descriptor() ([]byte, []int) { + return file_proxy_trojan_config_proto_rawDescGZIP(), []int{3} +} + +func (x *ServerConfig) GetUsers() []*protocol.User { + if x != nil { + return x.Users + } + return nil +} + +func (x *ServerConfig) GetFallback() *Fallback { + if x != nil { + return x.Fallback + } + return nil +} + +var File_proxy_trojan_config_proto protoreflect.FileDescriptor + +var file_proxy_trojan_config_proto_rawDesc = []byte{ + 0x0a, 0x19, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0x2f, 0x63, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x76, 0x32, 0x72, + 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x74, 0x72, + 0x6f, 0x6a, 0x61, 0x6e, 0x1a, 0x1a, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x1a, 0x21, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x73, 0x70, 0x65, 0x63, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x22, 0x25, 0x0a, 0x07, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x1a, + 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x32, 0x0a, 0x08, 0x46, 0x61, + 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x65, + 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x64, 0x65, 0x73, 0x74, 0x22, 0x52, + 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x42, + 0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2a, + 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, + 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x22, 0x85, 0x01, 0x0a, 0x0c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x05, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, + 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, + 0x55, 0x73, 0x65, 0x72, 0x52, 0x05, 0x75, 0x73, 0x65, 0x72, 0x73, 0x12, 0x3d, 0x0a, 0x08, 0x66, + 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, + 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, + 0x2e, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0x2e, 0x46, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, + 0x52, 0x08, 0x66, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x42, 0x56, 0x0a, 0x1b, 0x63, 0x6f, + 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x78, 0x79, 0x2e, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0x50, 0x01, 0x5a, 0x1b, 0x76, 0x32, 0x72, + 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, + 0x79, 0x2f, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0xaa, 0x02, 0x17, 0x56, 0x32, 0x52, 0x61, 0x79, + 0x2e, 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x54, 0x72, 0x6f, 0x6a, + 0x61, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_proxy_trojan_config_proto_rawDescOnce sync.Once + file_proxy_trojan_config_proto_rawDescData = file_proxy_trojan_config_proto_rawDesc +) + +func file_proxy_trojan_config_proto_rawDescGZIP() []byte { + file_proxy_trojan_config_proto_rawDescOnce.Do(func() { + file_proxy_trojan_config_proto_rawDescData = protoimpl.X.CompressGZIP(file_proxy_trojan_config_proto_rawDescData) + }) + return file_proxy_trojan_config_proto_rawDescData +} + +var file_proxy_trojan_config_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_proxy_trojan_config_proto_goTypes = []interface{}{ + (*Account)(nil), // 0: v2ray.core.proxy.trojan.Account + (*Fallback)(nil), // 1: v2ray.core.proxy.trojan.Fallback + (*ClientConfig)(nil), // 2: v2ray.core.proxy.trojan.ClientConfig + (*ServerConfig)(nil), // 3: v2ray.core.proxy.trojan.ServerConfig + (*protocol.ServerEndpoint)(nil), // 4: v2ray.core.common.protocol.ServerEndpoint + (*protocol.User)(nil), // 5: v2ray.core.common.protocol.User +} +var file_proxy_trojan_config_proto_depIdxs = []int32{ + 4, // 0: v2ray.core.proxy.trojan.ClientConfig.server:type_name -> v2ray.core.common.protocol.ServerEndpoint + 5, // 1: v2ray.core.proxy.trojan.ServerConfig.users:type_name -> v2ray.core.common.protocol.User + 1, // 2: v2ray.core.proxy.trojan.ServerConfig.fallback:type_name -> v2ray.core.proxy.trojan.Fallback + 3, // [3:3] is the sub-list for method output_type + 3, // [3:3] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_proxy_trojan_config_proto_init() } +func file_proxy_trojan_config_proto_init() { + if File_proxy_trojan_config_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_proxy_trojan_config_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Account); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_trojan_config_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Fallback); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_trojan_config_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ClientConfig); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_trojan_config_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ServerConfig); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_proxy_trojan_config_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_proxy_trojan_config_proto_goTypes, + DependencyIndexes: file_proxy_trojan_config_proto_depIdxs, + MessageInfos: file_proxy_trojan_config_proto_msgTypes, + }.Build() + File_proxy_trojan_config_proto = out.File + file_proxy_trojan_config_proto_rawDesc = nil + file_proxy_trojan_config_proto_goTypes = nil + file_proxy_trojan_config_proto_depIdxs = nil +} diff --git a/proxy/trojan/config.proto b/proxy/trojan/config.proto new file mode 100644 index 00000000..74092b1e --- /dev/null +++ b/proxy/trojan/config.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +package v2ray.core.proxy.trojan; +option csharp_namespace = "V2Ray.Core.Proxy.Trojan"; +option go_package = "v2ray.com/core/proxy/trojan"; +option java_package = "com.v2ray.core.proxy.trojan"; +option java_multiple_files = true; + +import "common/protocol/user.proto"; +import "common/protocol/server_spec.proto"; + +message Account { + string password = 1; +} + +message Fallback { + string type = 1; + string dest = 2; +} + +message ClientConfig { + repeated v2ray.core.common.protocol.ServerEndpoint server = 1; +} + +message ServerConfig { + repeated v2ray.core.common.protocol.User users = 1; + Fallback fallback = 2; +} diff --git a/proxy/trojan/errors.generated.go b/proxy/trojan/errors.generated.go new file mode 100644 index 00000000..d15be699 --- /dev/null +++ b/proxy/trojan/errors.generated.go @@ -0,0 +1,9 @@ +package trojan + +import "v2ray.com/core/common/errors" + +type errPathObjHolder struct{} + +func newError(values ...interface{}) *errors.Error { + return errors.New(values...).WithPathObj(errPathObjHolder{}) +} diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go new file mode 100644 index 00000000..d8007563 --- /dev/null +++ b/proxy/trojan/protocol.go @@ -0,0 +1,282 @@ +package trojan + +import ( + "encoding/binary" + "io" + + "v2ray.com/core/common/buf" + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" +) + +var ( + crlf = []byte{'\r', '\n'} + + addrParser = protocol.NewAddressParser( + protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4), // nolint: gomnd + protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6), // nolint: gomnd + protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain), // nolint: gomnd + ) +) + +const ( + maxLength = 8192 + + commandTCP byte = 1 + commandUDP byte = 3 +) + +// ConnWriter is TCP Connection Writer Wrapper for trojan protocol +type ConnWriter struct { + io.Writer + Target net.Destination + Account *MemoryAccount + headerSent bool +} + +// Write implements io.Writer +func (c *ConnWriter) Write(p []byte) (n int, err error) { + if !c.headerSent { + if err := c.writeHeader(); err != nil { + return 0, newError("failed to write request header").Base(err) + } + } + + return c.Writer.Write(p) +} + +// WriteMultiBuffer implements buf.Writer +func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + defer buf.ReleaseMulti(mb) + + for _, b := range mb { + if !b.IsEmpty() { + if _, err := c.Write(b.Bytes()); err != nil { + return err + } + } + } + + return nil +} + +func (c *ConnWriter) writeHeader() error { + buffer := buf.StackNew() + defer buffer.Release() + + command := commandTCP + if c.Target.Network == net.Network_UDP { + command = commandUDP + } + + if _, err := buffer.Write(c.Account.Key); err != nil { + return err + } + if _, err := buffer.Write(crlf); err != nil { + return err + } + if err := buffer.WriteByte(command); err != nil { + return err + } + if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil { + return err + } + if _, err := buffer.Write(crlf); err != nil { + return err + } + + _, err := c.Writer.Write(buffer.Bytes()) + if err == nil { + c.headerSent = true + } + + return err +} + +// PacketWriter UDP Connection Writer Wrapper for trojan protocol +type PacketWriter struct { + io.Writer + Target net.Destination +} + +// WriteMultiBuffer implements buf.Writer +func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + b := make([]byte, maxLength) + for !mb.IsEmpty() { + var length int + mb, length = buf.SplitBytes(mb, b) + if _, err := w.writePacket(b[:length], w.Target); err != nil { + buf.ReleaseMulti(mb) + return err + } + } + + return nil +} + +// WriteMultiBufferWithMetadata writes udp packet with destination specified +func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error { + b := make([]byte, maxLength) + for !mb.IsEmpty() { + var length int + mb, length = buf.SplitBytes(mb, b) + if _, err := w.writePacket(b[:length], dest); err != nil { + buf.ReleaseMulti(mb) + return err + } + } + + return nil +} + +func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam + buffer := buf.StackNew() + defer buffer.Release() + + length := len(payload) + lengthBuf := [2]byte{} + binary.BigEndian.PutUint16(lengthBuf[:], uint16(length)) + if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil { + return 0, err + } + if _, err := buffer.Write(lengthBuf[:]); err != nil { + return 0, err + } + if _, err := buffer.Write(crlf); err != nil { + return 0, err + } + if _, err := buffer.Write(payload); err != nil { + return 0, err + } + _, err := w.Write(buffer.Bytes()) + if err != nil { + return 0, err + } + + return length, nil +} + +// ConnReader is TCP Connection Reader Wrapper for trojan protocol +type ConnReader struct { + io.Reader + Target net.Destination + headerParsed bool +} + +// ParseHeader parses the trojan protocol header +func (c *ConnReader) ParseHeader() error { + var crlf [2]byte + var command [1]byte + var hash [56]byte + if _, err := io.ReadFull(c.Reader, hash[:]); err != nil { + return newError("failed to read user hash").Base(err) + } + + if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil { + return newError("failed to read crlf").Base(err) + } + + if _, err := io.ReadFull(c.Reader, command[:]); err != nil { + return newError("failed to read command").Base(err) + } + + network := net.Network_TCP + if command[0] == commandUDP { + network = net.Network_UDP + } + + addr, port, err := addrParser.ReadAddressPort(nil, c.Reader) + if err != nil { + return newError("failed to read address and port").Base(err) + } + c.Target = net.Destination{Network: network, Address: addr, Port: port} + + if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil { + return newError("failed to read crlf").Base(err) + } + + c.headerParsed = true + return nil +} + +// Read implements io.Reader +func (c *ConnReader) Read(p []byte) (int, error) { + if !c.headerParsed { + if err := c.ParseHeader(); err != nil { + return 0, err + } + } + + return c.Reader.Read(p) +} + +// ReadMultiBuffer implements buf.Reader +func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + b := buf.New() + _, err := b.ReadFrom(c) + return buf.MultiBuffer{b}, err +} + +// PacketPayload combines udp payload and destination +type PacketPayload struct { + Target net.Destination + Buffer buf.MultiBuffer +} + +// PacketReader is UDP Connection Reader Wrapper for trojan protocol +type PacketReader struct { + io.Reader +} + +// ReadMultiBuffer implements buf.Reader +func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + p, err := r.ReadMultiBufferWithMetadata() + if p != nil { + return p.Buffer, err + } + return nil, err +} + +// ReadMultiBufferWithMetadata reads udp packet with destination +func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) { + addr, port, err := addrParser.ReadAddressPort(nil, r) + if err != nil { + return nil, newError("failed to read address and port").Base(err) + } + + var lengthBuf [2]byte + if _, err := io.ReadFull(r, lengthBuf[:]); err != nil { + return nil, newError("failed to read payload length").Base(err) + } + + remain := int(binary.BigEndian.Uint16(lengthBuf[:])) + if remain > maxLength { + return nil, newError("oversize payload") + } + + var crlf [2]byte + if _, err := io.ReadFull(r, crlf[:]); err != nil { + return nil, newError("failed to read crlf").Base(err) + } + + dest := net.UDPDestination(addr, port) + var mb buf.MultiBuffer + for remain > 0 { + length := buf.Size + if remain < length { + length = remain + } + + b := buf.New() + mb = append(mb, b) + n, err := b.ReadFullFrom(r, int32(length)) + if err != nil { + buf.ReleaseMulti(mb) + return nil, newError("failed to read payload").Base(err) + } + + remain -= int(n) + } + + return &PacketPayload{Target: dest, Buffer: mb}, nil +} diff --git a/proxy/trojan/protocol_test.go b/proxy/trojan/protocol_test.go new file mode 100644 index 00000000..c30eabef --- /dev/null +++ b/proxy/trojan/protocol_test.go @@ -0,0 +1,91 @@ +package trojan_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "v2ray.com/core/common" + "v2ray.com/core/common/buf" + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" + . "v2ray.com/core/proxy/trojan" +) + +func toAccount(a *Account) protocol.Account { + account, err := a.AsAccount() + common.Must(err) + return account +} + +func TestTCPRequest(t *testing.T) { + user := &protocol.MemoryUser{ + Email: "love@v2ray.com", + Account: toAccount(&Account{ + Password: "password", + }), + } + payload := []byte("test string") + data := buf.New() + common.Must2(data.Write(payload)) + + buffer := buf.New() + defer buffer.Release() + + destination := net.Destination{Network: net.Network_TCP, Address: net.LocalHostIP, Port: 1234} + writer := &ConnWriter{Writer: buffer, Target: destination, Account: user.Account.(*MemoryAccount)} + common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{data})) + + reader := &ConnReader{Reader: buffer} + common.Must(reader.ParseHeader()) + + if r := cmp.Diff(reader.Target, destination); r != "" { + t.Error("destination: ", r) + } + + decodedData, err := reader.ReadMultiBuffer() + common.Must(err) + if r := cmp.Diff(decodedData[0].Bytes(), payload); r != "" { + t.Error("data: ", r) + } +} + +func TestUDPRequest(t *testing.T) { + user := &protocol.MemoryUser{ + Email: "love@v2ray.com", + Account: toAccount(&Account{ + Password: "password", + }), + } + payload := []byte("test string") + data := buf.New() + common.Must2(data.Write(payload)) + + buffer := buf.New() + defer buffer.Release() + + destination := net.Destination{Network: net.Network_UDP, Address: net.LocalHostIP, Port: 1234} + writer := &PacketWriter{Writer: &ConnWriter{Writer: buffer, Target: destination, Account: user.Account.(*MemoryAccount)}, Target: destination} + common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{data})) + + connReader := &ConnReader{Reader: buffer} + common.Must(connReader.ParseHeader()) + + packetReader := &PacketReader{Reader: connReader} + p, err := packetReader.ReadMultiBufferWithMetadata() + common.Must(err) + + if p.Buffer.IsEmpty() { + t.Error("no request data") + } + + if r := cmp.Diff(p.Target, destination); r != "" { + t.Error("destination: ", r) + } + + mb, decoded := buf.SplitFirst(p.Buffer) + buf.ReleaseMulti(mb) + + if r := cmp.Diff(decoded.Bytes(), payload); r != "" { + t.Error("data: ", r) + } +} diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go new file mode 100644 index 00000000..577828c3 --- /dev/null +++ b/proxy/trojan/server.go @@ -0,0 +1,290 @@ +// +build !confonly + +package trojan + +import ( + "context" + "io" + "time" + + "v2ray.com/core" + "v2ray.com/core/common" + "v2ray.com/core/common/buf" + "v2ray.com/core/common/errors" + "v2ray.com/core/common/log" + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" + udp_proto "v2ray.com/core/common/protocol/udp" + "v2ray.com/core/common/retry" + "v2ray.com/core/common/session" + "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" + "v2ray.com/core/features/policy" + "v2ray.com/core/features/routing" + "v2ray.com/core/transport/internet" + "v2ray.com/core/transport/internet/udp" +) + +func init() { + common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { // nolint: lll + return NewServer(ctx, config.(*ServerConfig)) + })) +} + +// Server is an inbound connection handler that handles messages in trojan protocol. +type Server struct { + validator *Validator + policyManager policy.Manager + config *ServerConfig +} + +// NewServer creates a new trojan inbound handler. +func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { + validator := new(Validator) + for _, user := range config.Users { + u, err := user.ToMemoryUser() + if err != nil { + return nil, newError("failed to get trojan user").Base(err).AtError() + } + + if err := validator.Add(u); err != nil { + return nil, newError("failed to add user").Base(err).AtError() + } + } + + v := core.MustFromContext(ctx) + server := &Server{ + policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + validator: validator, + config: config, + } + + return server, nil +} + +// Network implements proxy.Inbound.Network(). +func (s *Server) Network() []net.Network { + return []net.Network{net.Network_TCP} +} + +// Process implements proxy.Inbound.Process(). +func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { // nolint: funlen,lll + sessionPolicy := s.policyManager.ForLevel(0) + if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { + return newError("unable to set read deadline").Base(err).AtWarning() + } + + buffer := buf.New() + defer buffer.Release() + + n, err := buffer.ReadFrom(conn) + if err != nil { + return newError("failed to read first request").Base(err) + } + + bufferedReader := &buf.BufferedReader{ + Reader: buf.NewReader(conn), + Buffer: buf.MultiBuffer{buffer}, + } + + var user *protocol.MemoryUser + fallbackEnabled := s.config.Fallback != nil + shouldFallback := false + if n < 56 { // nolint: gomnd + // invalid protocol + log.Record(&log.AccessMessage{ + From: conn.RemoteAddr(), + To: "", + Status: log.AccessRejected, + Reason: newError("not trojan protocol"), + }) + + shouldFallback = true + } else { + user = s.validator.Get(hexString(buffer.BytesTo(56))) // nolint: gomnd + if user == nil { + // invalid user, let's fallback + log.Record(&log.AccessMessage{ + From: conn.RemoteAddr(), + To: "", + Status: log.AccessRejected, + Reason: newError("not a valid user"), + }) + + shouldFallback = true + } + } + + if fallbackEnabled && shouldFallback { + return s.fallback(ctx, sessionPolicy, bufferedReader, buf.NewWriter(conn)) + } else if shouldFallback { + return newError("invalid protocol or invalid user") + } + + clientReader := &ConnReader{Reader: bufferedReader} + if err := clientReader.ParseHeader(); err != nil { + log.Record(&log.AccessMessage{ + From: conn.RemoteAddr(), + To: "", + Status: log.AccessRejected, + Reason: err, + }) + return newError("failed to create request from: ", conn.RemoteAddr()).Base(err) + } + + destination := clientReader.Target + if err := conn.SetReadDeadline(time.Time{}); err != nil { + return newError("unable to set read deadline").Base(err).AtWarning() + } + + inbound := session.InboundFromContext(ctx) + if inbound == nil { + panic("no inbound metadata") + } + inbound.User = user + sessionPolicy = s.policyManager.ForLevel(user.Level) + + if destination.Network == net.Network_UDP { // handle udp request + return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher) + } + + // handle tcp request + + log.ContextWithAccessMessage(ctx, &log.AccessMessage{ + From: conn.RemoteAddr(), + To: destination, + Status: log.AccessAccepted, + Reason: "", + Email: user.Email, + }) + + newError("received request for ", destination).WriteToLog(session.ExportIDToError(ctx)) + return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher) +} + +func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { // nolint: lll + udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { + common.Must(clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source)) + }) + + inbound := session.InboundFromContext(ctx) + user := inbound.User + + for { + select { + case <-ctx.Done(): + return nil + default: + p, err := clientReader.ReadMultiBufferWithMetadata() + if err != nil { + if errors.Cause(err) != io.EOF { + return newError("unexpected EOF").Base(err) + } + return nil + } + + log.ContextWithAccessMessage(ctx, &log.AccessMessage{ + From: inbound.Source, + To: p.Target, + Status: log.AccessAccepted, + Reason: "", + Email: user.Email, + }) + newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx)) + + for _, b := range p.Buffer { + udpServer.Dispatch(ctx, p.Target, b) + } + } + } +} + +func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session, + destination net.Destination, + clientReader buf.Reader, + clientWriter buf.Writer, dispatcher routing.Dispatcher) error { + ctx, cancel := context.WithCancel(ctx) + timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) + + link, err := dispatcher.Dispatch(ctx, destination) + if err != nil { + return newError("failed to dispatch request to ", destination).Base(err) + } + + requestDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + + if err := buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer)); err != nil { + return newError("failed to transfer request").Base(err) + } + return nil + } + + responseDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + + if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil { + return newError("failed to write response").Base(err) + } + return nil + } + + var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { + common.Must(common.Interrupt(link.Reader)) + common.Must(common.Interrupt(link.Writer)) + return newError("connection ends").Base(err) + } + + return nil +} + +func (s *Server) fallback(ctx context.Context, sessionPolicy policy.Session, requestReader buf.Reader, responseWriter buf.Writer) error { // nolint: lll + ctx, cancel := context.WithCancel(ctx) + timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) + + var conn net.Conn + var err error + fb := s.config.Fallback + if err := retry.ExponentialBackoff(5, 100).On(func() error { // nolint: gomnd + var dialer net.Dialer + conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest) + if err != nil { + return err + } + return nil + }); err != nil { + return newError("failed to dial to " + fb.Dest).Base(err).AtWarning() + } + defer conn.Close() + + serverReader := buf.NewReader(conn) + serverWriter := buf.NewWriter(conn) + + requestDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + + if err := buf.Copy(requestReader, serverWriter, buf.UpdateActivity(timer)); err != nil { + return newError("failed to fallback request payload").Base(err).AtInfo() + } + return nil + } + + responseDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + if err := buf.Copy(serverReader, responseWriter, buf.UpdateActivity(timer)); err != nil { + return newError("failed to deliver response payload").Base(err).AtInfo() + } + return nil + } + + if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(serverWriter)), task.OnSuccess(responseDone, task.Close(responseWriter))); err != nil { // nolint: lll + common.Must(common.Interrupt(serverReader)) + common.Must(common.Interrupt(serverWriter)) + return newError("fallback ends").Base(err).AtInfo() + } + + return nil +} diff --git a/proxy/trojan/trojan.go b/proxy/trojan/trojan.go new file mode 100644 index 00000000..73b3154f --- /dev/null +++ b/proxy/trojan/trojan.go @@ -0,0 +1 @@ +package trojan diff --git a/proxy/trojan/validator.go b/proxy/trojan/validator.go new file mode 100644 index 00000000..1c7926e3 --- /dev/null +++ b/proxy/trojan/validator.go @@ -0,0 +1,28 @@ +package trojan + +import ( + "sync" + + "v2ray.com/core/common/protocol" +) + +// Validator stores valid trojan users +type Validator struct { + users sync.Map +} + +// Add a trojan user +func (v *Validator) Add(u *protocol.MemoryUser) error { + user := u.Account.(*MemoryAccount) + v.users.Store(hexString(user.Key), u) + return nil +} + +// Get user with hashed key, nil if user doesn't exist. +func (v *Validator) Get(hash string) *protocol.MemoryUser { + u, _ := v.users.Load(hash) + if u != nil { + return u.(*protocol.MemoryUser) + } + return nil +} diff --git a/proxy/vless/encoding/addons.go b/proxy/vless/encoding/addons.go index a69c4109..b35e609e 100644 --- a/proxy/vless/encoding/addons.go +++ b/proxy/vless/encoding/addons.go @@ -9,11 +9,25 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/protocol" + "v2ray.com/core/proxy/vless" ) func EncodeHeaderAddons(buffer *buf.Buffer, addons *Addons) error { switch addons.Flow { + case vless.XRO: + + bytes, err := proto.Marshal(addons) + if err != nil { + return newError("failed to marshal addons protobuf value").Base(err) + } + if err := buffer.WriteByte(byte(len(bytes))); err != nil { + return newError("failed to write addons protobuf length").Base(err) + } + if _, err := buffer.Write(bytes); err != nil { + return newError("failed to write addons protobuf value").Base(err) + } + default: if err := buffer.WriteByte(0); err != nil { @@ -64,10 +78,14 @@ func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, addons switch addons.Flow { default: - return buf.NewWriter(writer) + if request.Command == protocol.RequestCommandUDP { + return NewMultiLengthPacketWriter(writer.(buf.Writer)) + } } + return buf.NewWriter(writer) + } // DecodeBodyAddons returns a Reader from which caller can fetch decrypted body. @@ -76,8 +94,118 @@ func DecodeBodyAddons(reader io.Reader, request *protocol.RequestHeader, addons switch addons.Flow { default: - return buf.NewReader(reader) + if request.Command == protocol.RequestCommandUDP { + return NewLengthPacketReader(reader) + } } + return buf.NewReader(reader) + +} + +func NewMultiLengthPacketWriter(writer buf.Writer) *MultiLengthPacketWriter { + return &MultiLengthPacketWriter{ + Writer: writer, + } +} + +type MultiLengthPacketWriter struct { + buf.Writer +} + +func (w *MultiLengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + defer buf.ReleaseMulti(mb) + mb2Write := make(buf.MultiBuffer, 0, len(mb)+1) + for _, b := range mb { + length := b.Len() + if length == 0 || length+2 > buf.Size { + continue + } + eb := buf.New() + if err := eb.WriteByte(byte(length >> 8)); err != nil { + eb.Release() + continue + } + if err := eb.WriteByte(byte(length)); err != nil { + eb.Release() + continue + } + if _, err := eb.Write(b.Bytes()); err != nil { + eb.Release() + continue + } + mb2Write = append(mb2Write, eb) + } + if mb2Write.IsEmpty() { + return nil + } + return w.Writer.WriteMultiBuffer(mb2Write) +} + +func NewLengthPacketWriter(writer io.Writer) *LengthPacketWriter { + return &LengthPacketWriter{ + Writer: writer, + cache: make([]byte, 0, 65536), + } +} + +type LengthPacketWriter struct { + io.Writer + cache []byte +} + +func (w *LengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + length := mb.Len() // none of mb is nil + //fmt.Println("Write", length) + if length == 0 { + return nil + } + defer func() { + w.cache = w.cache[:0] + }() + w.cache = append(w.cache, byte(length>>8), byte(length)) + for i, b := range mb { + w.cache = append(w.cache, b.Bytes()...) + b.Release() + mb[i] = nil + } + if _, err := w.Write(w.cache); err != nil { + return newError("failed to write a packet").Base(err) + } + return nil +} + +func NewLengthPacketReader(reader io.Reader) *LengthPacketReader { + return &LengthPacketReader{ + Reader: reader, + cache: make([]byte, 2), + } +} + +type LengthPacketReader struct { + io.Reader + cache []byte +} + +func (r *LengthPacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + if _, err := io.ReadFull(r.Reader, r.cache); err != nil { // maybe EOF + return nil, newError("failed to read packet length").Base(err) + } + length := int32(r.cache[0])<<8 | int32(r.cache[1]) + //fmt.Println("Read", length) + mb := make(buf.MultiBuffer, 0, length/buf.Size+1) + for length > 0 { + size := length + if size > buf.Size { + size = buf.Size + } + length -= size + b := buf.New() + if _, err := b.ReadFullFrom(r.Reader, size); err != nil { + return nil, newError("failed to read packet payload").Base(err) + } + mb = append(mb, b) + } + return mb, nil } diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index a0d5b54f..b56b30ef 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -153,23 +153,23 @@ func EncodeResponseHeader(writer io.Writer, request *protocol.RequestHeader, res } // DecodeResponseHeader decodes and returns (if successful) a ResponseHeader from an input stream. -func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader, responseAddons *Addons) error { +func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*Addons, error) { buffer := buf.StackNew() defer buffer.Release() if _, err := buffer.ReadFullFrom(reader, 1); err != nil { - return newError("failed to read response version").Base(err) + return nil, newError("failed to read response version").Base(err) } if buffer.Byte(0) != request.Version { - return newError("unexpected response version. Expecting ", int(request.Version), " but actually ", int(buffer.Byte(0))) + return nil, newError("unexpected response version. Expecting ", int(request.Version), " but actually ", int(buffer.Byte(0))) } responseAddons, err := DecodeHeaderAddons(&buffer, reader) if err != nil { - return newError("failed to decode response header addons").Base(err) + return nil, newError("failed to decode response header addons").Base(err) } - return nil + return responseAddons, nil } diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 865aafc1..803fce33 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -6,7 +6,6 @@ package inbound import ( "context" - "encoding/hex" "io" "strconv" "time" @@ -17,6 +16,7 @@ import ( "v2ray.com/core/common/errors" "v2ray.com/core/common/log" "v2ray.com/core/common/net" + "v2ray.com/core/common/platform" "v2ray.com/core/common/protocol" "v2ray.com/core/common/retry" "v2ray.com/core/common/session" @@ -30,6 +30,11 @@ import ( "v2ray.com/core/proxy/vless/encoding" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" + "v2ray.com/core/transport/internet/xtls" +) + +var ( + xtls_show = false ) func init() { @@ -43,6 +48,13 @@ func init() { } return New(ctx, config.(*Config), dc) })) + + const defaultFlagValue = "NOT_DEFINED_AT_ALL" + + xtlsShow := platform.NewEnvFlag("v2ray.vless.xtls.show").GetValue(func() string { return defaultFlagValue }) + if xtlsShow == "true" { + xtls_show = true + } } // Handler is an inbound connection handler that handles messages in VLess protocol. @@ -135,6 +147,11 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i sid := session.ExportIDToError(ctx) + iConn := connection + if statConn, ok := iConn.(*internet.StatCouterConnection); ok { + iConn = statConn.Connection + } + sessionPolicy := h.policyManager.ForLevel(0) if err := connection.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { return newError("unable to set read deadline").Base(err).AtWarning() @@ -183,16 +200,15 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i alpn := "" if len(apfb) > 1 || apfb[""] == nil { - iConn := connection - if statConn, ok := iConn.(*internet.StatCouterConnection); ok { - iConn = statConn.Connection - } if tlsConn, ok := iConn.(*tls.Conn); ok { alpn = tlsConn.ConnectionState().NegotiatedProtocol newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid) - if apfb[alpn] == nil { - alpn = "" - } + } else if xtlsConn, ok := iConn.(*xtls.Conn); ok { + alpn = xtlsConn.ConnectionState().NegotiatedProtocol + newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid) + } + if apfb[alpn] == nil { + alpn = "" } } pfb := apfb[alpn] @@ -307,18 +323,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i pro.Write(net.ParseIP(remoteAddr).To16()) pro.Write(net.ParseIP(localAddr).To16()) } - p1, _ := strconv.ParseInt(remotePort, 10, 64) - b1, _ := hex.DecodeString(strconv.FormatInt(p1, 16)) - p2, _ := strconv.ParseInt(localPort, 10, 64) - b2, _ := hex.DecodeString(strconv.FormatInt(p2, 16)) - if len(b1) == 1 { - pro.WriteByte(0) - } - pro.Write(b1) - if len(b2) == 1 { - pro.WriteByte(0) - } - pro.Write(b2) + p1, _ := strconv.ParseUint(remotePort, 10, 16) + p2, _ := strconv.ParseUint(localPort, 10, 16) + pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)}) } if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil { return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning() @@ -376,6 +383,34 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i } inbound.User = request.User + account := request.User.Account.(*vless.MemoryAccount) + + responseAddons := &encoding.Addons{ + //Flow: requestAddons.Flow, + } + + switch requestAddons.Flow { + case vless.XRO: + if account.Flow == vless.XRO { + switch request.Command { + case protocol.RequestCommandMux: + return newError(vless.XRO + " doesn't support Mux").AtWarning() + case protocol.RequestCommandUDP: + return newError(vless.XRO + " doesn't support UDP").AtWarning() + case protocol.RequestCommandTCP: + if xtlsConn, ok := iConn.(*xtls.Conn); ok { + xtlsConn.RPRX = true + xtlsConn.SHOW = xtls_show + xtlsConn.MARK = "XTLS" + } else { + return newError(`failed to use ` + vless.XRO + `, maybe "security" is not "xtls"`).AtWarning() + } + } + } else { + return newError(account.ID.String() + " is not able to use " + vless.XRO).AtWarning() + } + } + if request.Command != protocol.RequestCommandMux { ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ From: connection.RemoteAddr(), @@ -396,8 +431,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i return newError("failed to dispatch request to ", request.Destination()).Base(err).AtWarning() } - serverReader := link.Reader - serverWriter := link.Writer + serverReader := link.Reader // .(*pipe.Reader) + serverWriter := link.Writer // .(*pipe.Writer) postRequest := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) @@ -416,10 +451,6 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i getResponse := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) - responseAddons := &encoding.Addons{ - Flow: requestAddons.Flow, - } - bufferWriter := buf.NewBufferedWriter(buf.NewWriter(connection)) if err := encoding.EncodeResponseHeader(bufferWriter, request, responseAddons); err != nil { return newError("failed to encode response header").Base(err).AtWarning() diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 161ada2f..3be02778 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -12,6 +12,7 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" + "v2ray.com/core/common/platform" "v2ray.com/core/common/protocol" "v2ray.com/core/common/retry" "v2ray.com/core/common/session" @@ -22,12 +23,24 @@ import ( "v2ray.com/core/proxy/vless/encoding" "v2ray.com/core/transport" "v2ray.com/core/transport/internet" + "v2ray.com/core/transport/internet/xtls" +) + +var ( + xtls_show = false ) func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { return New(ctx, config.(*Config)) })) + + const defaultFlagValue = "NOT_DEFINED_AT_ALL" + + xtlsShow := platform.NewEnvFlag("v2ray.vless.xtls.show").GetValue(func() string { return defaultFlagValue }) + if xtlsShow == "true" { + xtls_show = true + } } // Handler is an outbound connection handler for VLess protocol. @@ -60,13 +73,13 @@ func New(ctx context.Context, config *Config) (*Handler, error) { } // Process implements proxy.Outbound.Process(). -func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { +func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { var rec *protocol.ServerSpec var conn internet.Connection if err := retry.ExponentialBackoff(5, 200).On(func() error { - rec = v.serverPicker.PickServer() + rec = h.serverPicker.PickServer() var err error conn, err = dialer.Dial(ctx, rec.Destination()) if err != nil { @@ -78,6 +91,11 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } defer conn.Close() // nolint: errcheck + iConn := conn + if statConn, ok := iConn.(*internet.StatCouterConnection); ok { + iConn = statConn.Connection + } + outbound := session.OutboundFromContext(ctx) if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified").AtError() @@ -108,12 +126,38 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte Flow: account.Flow, } - sessionPolicy := v.policyManager.ForLevel(request.User.Level) + switch requestAddons.Flow { + case vless.XRO, vless.XRO + "-udp443": + switch request.Command { + case protocol.RequestCommandMux: + return newError(vless.XRO + " doesn't support Mux").AtWarning() + case protocol.RequestCommandUDP: + if requestAddons.Flow == vless.XRO && request.Port == 443 { + return newError(vless.XRO + " stopped UDP/443").AtWarning() + } + requestAddons.Flow = "" + case protocol.RequestCommandTCP: + if xtlsConn, ok := iConn.(*xtls.Conn); ok { + xtlsConn.RPRX = true + xtlsConn.SHOW = xtls_show + xtlsConn.MARK = "XTLS" + } else { + return newError(`failed to use ` + vless.XRO + `, maybe "security" is not "xtls"`).AtWarning() + } + requestAddons.Flow = vless.XRO + } + default: + if _, ok := iConn.(*xtls.Conn); ok { + panic(`To avoid misunderstanding, you must fill in VLESS "flow" when using XTLS.`) + } + } + + sessionPolicy := h.policyManager.ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) - clientReader := link.Reader - clientWriter := link.Writer + clientReader := link.Reader // .(*pipe.Reader) + clientWriter := link.Writer // .(*pipe.Writer) postRequest := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) @@ -151,9 +195,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte getResponse := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) - responseAddons := new(encoding.Addons) - - if err := encoding.DecodeResponseHeader(conn, request, responseAddons); err != nil { + responseAddons, err := encoding.DecodeResponseHeader(conn, request) + if err != nil { return newError("failed to decode response header").Base(err).AtWarning() } diff --git a/proxy/vless/vless.go b/proxy/vless/vless.go index 9e6dc7ab..ea51e563 100644 --- a/proxy/vless/vless.go +++ b/proxy/vless/vless.go @@ -6,3 +6,7 @@ package vless //go:generate errorgen + +const ( + XRO = "xtls-rprx-origin" +) diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index cdff7c8a..78945532 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -12,8 +12,6 @@ import ( "hash" "hash/fnv" "io" - "os" - vmessaead "v2ray.com/core/proxy/vmess/aead" "golang.org/x/crypto/chacha20poly1305" @@ -25,6 +23,7 @@ import ( "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" "v2ray.com/core/proxy/vmess" + vmessaead "v2ray.com/core/proxy/vmess/aead" ) func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte { @@ -37,6 +36,7 @@ func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte { // ClientSession stores connection session info for VMess client. type ClientSession struct { + isAEAD bool idHash protocol.IDHash requestBodyKey [16]byte requestBodyIV [16]byte @@ -44,35 +44,23 @@ type ClientSession struct { responseBodyIV [16]byte responseReader io.Reader responseHeader byte - - isAEADRequest bool } // NewClientSession creates a new ClientSession. -func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSession { +func NewClientSession(isAEAD bool, idHash protocol.IDHash, ctx context.Context) *ClientSession { + + session := &ClientSession{ + isAEAD: isAEAD, + idHash: idHash, + } + randomBytes := make([]byte, 33) // 16 + 16 + 1 common.Must2(rand.Read(randomBytes)) - - session := &ClientSession{} - - session.isAEADRequest = false - - if ctxValueAlterID := ctx.Value(vmess.AlterID); ctxValueAlterID != nil { - if ctxValueAlterID == 0 { - session.isAEADRequest = true - } - } - - if vmessAeadDisable, vmessAeadDisableFound := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); vmessAeadDisableFound { - if vmessAeadDisable == "true" { - session.isAEADRequest = false - } - } - copy(session.requestBodyKey[:], randomBytes[:16]) copy(session.requestBodyIV[:], randomBytes[16:32]) session.responseHeader = randomBytes[32] - if !session.isAEADRequest { + + if !session.isAEAD { session.responseBodyKey = md5.Sum(session.requestBodyKey[:]) session.responseBodyIV = md5.Sum(session.requestBodyIV[:]) } else { @@ -82,15 +70,13 @@ func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSessio copy(session.responseBodyIV[:], BodyIV[:16]) } - session.idHash = idHash - return session } func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() account := header.User.Account.(*vmess.MemoryAccount) - if !c.isAEADRequest { + if !c.isAEAD { idHash := c.idHash(account.AnyValidID().Bytes()) common.Must2(serial.WriteUint64(idHash, uint64(timestamp))) common.Must2(writer.Write(idHash.Sum(nil))) @@ -126,7 +112,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ fnv1a.Sum(hashBytes[:0]) } - if !c.isAEADRequest { + if !c.isAEAD { iv := hashTimestamp(md5.New(), timestamp) aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:]) aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) @@ -203,7 +189,7 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { - if !c.isAEADRequest { + if !c.isAEAD { aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) c.responseReader = crypto.NewCryptionReader(aesStream, reader) } else { @@ -274,7 +260,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon header.Command = command } } - if c.isAEADRequest { + if c.isAEAD { aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) c.responseReader = crypto.NewCryptionReader(aesStream, reader) } diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index bc7eecd3..c0f938b7 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -43,7 +43,7 @@ func TestRequestSerialization(t *testing.T) { } buffer := buf.New() - client := NewClientSession(protocol.DefaultIDHash, context.TODO()) + client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() @@ -93,7 +93,7 @@ func TestInvalidRequest(t *testing.T) { } buffer := buf.New() - client := NewClientSession(protocol.DefaultIDHash, context.TODO()) + client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() @@ -134,7 +134,7 @@ func TestMuxRequest(t *testing.T) { } buffer := buf.New() - client := NewClientSession(protocol.DefaultIDHash, context.TODO()) + client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index cbbbf585..b1592766 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -54,12 +54,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) { } // Process implements proxy.Outbound.Process(). -func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { +func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { var rec *protocol.ServerSpec var conn internet.Connection err := retry.ExponentialBackoff(5, 200).On(func() error { - rec = v.serverPicker.PickServer() + rec = h.serverPicker.PickServer() rawConn, err := dialer.Dial(ctx, rec.Destination()) if err != nil { return err @@ -113,10 +113,13 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte input := link.Reader output := link.Writer - ctx = context.WithValue(ctx, vmess.AlterID, len(account.AlterIDs)) + isAEAD := false + if !aead_disabled && len(account.AlterIDs) == 0 { + isAEAD = true + } - session := encoding.NewClientSession(protocol.DefaultIDHash, ctx) - sessionPolicy := v.policyManager.ForLevel(request.User.Level) + session := encoding.NewClientSession(isAEAD, protocol.DefaultIDHash, ctx) + sessionPolicy := h.policyManager.ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) @@ -159,7 +162,7 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if err != nil { return newError("failed to read header").Base(err) } - v.handleCommand(rec.Destination(), header.Command) + h.handleCommand(rec.Destination(), header.Command) bodyReader := session.DecodeResponseBody(request, reader) @@ -176,6 +179,7 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte var ( enablePadding = false + aead_disabled = false ) func shouldEnablePadding(s protocol.SecurityType) bool { @@ -188,8 +192,14 @@ func init() { })) const defaultFlagValue = "NOT_DEFINED_AT_ALL" + paddingValue := platform.NewEnvFlag("v2ray.vmess.padding").GetValue(func() string { return defaultFlagValue }) if paddingValue != defaultFlagValue { enablePadding = true } + + aeadDisabled := platform.NewEnvFlag("v2ray.vmess.aead.disabled").GetValue(func() string { return defaultFlagValue }) + if aeadDisabled == "true" { + aead_disabled = true + } } diff --git a/proxy/vmess/vmessCtxInterface.go b/proxy/vmess/vmessCtxInterface.go index dbfb5b72..5d26f9e5 100644 --- a/proxy/vmess/vmessCtxInterface.go +++ b/proxy/vmess/vmessCtxInterface.go @@ -1,3 +1,4 @@ package vmess +// example const AlterID = "VMessCtxInterface_AlterID" diff --git a/release/user-package.sh b/release/user-package.sh index d0bef1c4..4161d274 100755 --- a/release/user-package.sh +++ b/release/user-package.sh @@ -50,10 +50,10 @@ build_v2() { build_dat() { echo ">>> Downloading newest geoip ..." - curl -s -L -o "$TMP"/geoip.dat "https://github.com/v2ray/geoip/raw/release/geoip.dat" + curl -s -L -o "$TMP"/geoip.dat "https://github.com/v2fly/geoip/raw/release/geoip.dat" echo ">>> Downloading newest geosite ..." - curl -s -L -o "$TMP"/geosite.dat "https://github.com/v2ray/domain-list-community/raw/release/dlc.dat" + curl -s -L -o "$TMP"/geosite.dat "https://github.com/v2fly/domain-list-community/raw/release/dlc.dat" } copyconf() { diff --git a/transport/internet/domainsocket/dial.go b/transport/internet/domainsocket/dial.go index bd6591de..d3a43f48 100644 --- a/transport/internet/domainsocket/dial.go +++ b/transport/internet/domainsocket/dial.go @@ -11,6 +11,7 @@ import ( "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" + "v2ray.com/core/transport/internet/xtls" ) func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { @@ -27,6 +28,8 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { return tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest))), nil + } else if config := xtls.ConfigFromStreamSettings(streamSettings); config != nil { + return xtls.Client(conn, config.GetXTLSConfig(xtls.WithDestination(dest))), nil } return conn, nil diff --git a/transport/internet/domainsocket/listener.go b/transport/internet/domainsocket/listener.go index 9a98971e..607dbbd0 100644 --- a/transport/internet/domainsocket/listener.go +++ b/transport/internet/domainsocket/listener.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/pires/go-proxyproto" + goxtls "github.com/xtls/go" "golang.org/x/sys/unix" "v2ray.com/core/common" @@ -18,15 +19,17 @@ import ( "v2ray.com/core/common/session" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" + "v2ray.com/core/transport/internet/xtls" ) type Listener struct { - addr *net.UnixAddr - ln net.Listener - tlsConfig *gotls.Config - config *Config - addConn internet.ConnHandler - locker *fileLocker + addr *net.UnixAddr + ln net.Listener + tlsConfig *gotls.Config + xtlsConfig *goxtls.Config + config *Config + addConn internet.ConnHandler + locker *fileLocker } func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { @@ -73,6 +76,9 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { ln.tlsConfig = config.GetTLSConfig() } + if config := xtls.ConfigFromStreamSettings(streamSettings); config != nil { + ln.xtlsConfig = config.GetXTLSConfig() + } go ln.run() @@ -103,6 +109,8 @@ func (ln *Listener) run() { if ln.tlsConfig != nil { conn = tls.Server(conn, ln.tlsConfig) + } else if ln.xtlsConfig != nil { + conn = xtls.Server(conn, ln.xtlsConfig) } ln.addConn(internet.Connection(conn)) diff --git a/transport/internet/kcp/config.go b/transport/internet/kcp/config.go index 0abd12d5..f6016185 100644 --- a/transport/internet/kcp/config.go +++ b/transport/internet/kcp/config.go @@ -4,8 +4,6 @@ package kcp import ( "crypto/cipher" - "fmt" - "v2ray.com/core/common" "v2ray.com/core/transport/internet" ) @@ -63,7 +61,6 @@ func (c *Config) GetReadBufferSize() uint32 { // GetSecurity returns the security settings. func (c *Config) GetSecurity() (cipher.AEAD, error) { if c.Seed != nil { - fmt.Println("=========NewAEADAESGCMBasedOnSeed Used============") return NewAEADAESGCMBasedOnSeed(c.Seed.Seed), nil } return NewSimpleAuthenticator(), nil diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index e6f414e0..6c744dea 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -10,6 +10,7 @@ import ( "v2ray.com/core/common/session" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" + "v2ray.com/core/transport/internet/xtls" ) // Dial dials a new TCP connection to the given destination. @@ -30,6 +31,9 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } */ conn = tls.Client(conn, tlsConfig) + } else if config := xtls.ConfigFromStreamSettings(streamSettings); config != nil { + xtlsConfig := config.GetXTLSConfig(xtls.WithDestination(dest)) + conn = xtls.Client(conn, xtlsConfig) } tcpSettings := streamSettings.ProtocolSettings.(*Config) diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 7813c5e4..de13e76f 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -9,18 +9,21 @@ import ( "time" "github.com/pires/go-proxyproto" + goxtls "github.com/xtls/go" "v2ray.com/core/common" "v2ray.com/core/common/net" "v2ray.com/core/common/session" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" + "v2ray.com/core/transport/internet/xtls" ) // Listener is an internet.Listener that listens for TCP connections. type Listener struct { listener net.Listener tlsConfig *gotls.Config + xtlsConfig *goxtls.Config authConfig internet.ConnectionAuthenticator config *Config addConn internet.ConnHandler @@ -59,6 +62,9 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { l.tlsConfig = config.GetTLSConfig(tls.WithNextProto("h2")) } + if config := xtls.ConfigFromStreamSettings(streamSettings); config != nil { + l.xtlsConfig = config.GetXTLSConfig(xtls.WithNextProto("h2")) + } if tcpSettings.HeaderSettings != nil { headerConfig, err := tcpSettings.HeaderSettings.GetInstance() @@ -93,6 +99,8 @@ func (v *Listener) keepAccepting() { if v.tlsConfig != nil { conn = tls.Server(conn, v.tlsConfig) + } else if v.xtlsConfig != nil { + conn = xtls.Server(conn, v.xtlsConfig) } if v.authConfig != nil { conn = v.authConfig.Server(conn) diff --git a/transport/internet/xtls/config.go b/transport/internet/xtls/config.go new file mode 100644 index 00000000..580d1cea --- /dev/null +++ b/transport/internet/xtls/config.go @@ -0,0 +1,231 @@ +// +build !confonly + +package xtls + +import ( + "crypto/x509" + "sync" + "time" + + xtls "github.com/xtls/go" + + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol/tls/cert" + "v2ray.com/core/transport/internet" +) + +var ( + globalSessionCache = xtls.NewLRUClientSessionCache(128) +) + +// ParseCertificate converts a cert.Certificate to Certificate. +func ParseCertificate(c *cert.Certificate) *Certificate { + certPEM, keyPEM := c.ToPEM() + return &Certificate{ + Certificate: certPEM, + Key: keyPEM, + } +} + +func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { + root := x509.NewCertPool() + for _, cert := range c.Certificate { + if !root.AppendCertsFromPEM(cert.Certificate) { + return nil, newError("failed to append cert").AtWarning() + } + } + return root, nil +} + +// BuildCertificates builds a list of TLS certificates from proto definition. +func (c *Config) BuildCertificates() []xtls.Certificate { + certs := make([]xtls.Certificate, 0, len(c.Certificate)) + for _, entry := range c.Certificate { + if entry.Usage != Certificate_ENCIPHERMENT { + continue + } + keyPair, err := xtls.X509KeyPair(entry.Certificate, entry.Key) + if err != nil { + newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog() + continue + } + certs = append(certs, keyPair) + } + return certs +} + +func isCertificateExpired(c *xtls.Certificate) bool { + if c.Leaf == nil && len(c.Certificate) > 0 { + if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil { + c.Leaf = pc + } + } + + // If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate. + return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(-time.Minute)) +} + +func issueCertificate(rawCA *Certificate, domain string) (*xtls.Certificate, error) { + parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key) + if err != nil { + return nil, newError("failed to parse raw certificate").Base(err) + } + newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain)) + if err != nil { + return nil, newError("failed to generate new certificate for ", domain).Base(err) + } + newCertPEM, newKeyPEM := newCert.ToPEM() + cert, err := xtls.X509KeyPair(newCertPEM, newKeyPEM) + return &cert, err +} + +func (c *Config) getCustomCA() []*Certificate { + certs := make([]*Certificate, 0, len(c.Certificate)) + for _, certificate := range c.Certificate { + if certificate.Usage == Certificate_AUTHORITY_ISSUE { + certs = append(certs, certificate) + } + } + return certs +} + +func getGetCertificateFunc(c *xtls.Config, ca []*Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) { + var access sync.RWMutex + + return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) { + domain := hello.ServerName + certExpired := false + + access.RLock() + certificate, found := c.NameToCertificate[domain] + access.RUnlock() + + if found { + if !isCertificateExpired(certificate) { + return certificate, nil + } + certExpired = true + } + + if certExpired { + newCerts := make([]xtls.Certificate, 0, len(c.Certificates)) + + access.Lock() + for _, certificate := range c.Certificates { + if !isCertificateExpired(&certificate) { + newCerts = append(newCerts, certificate) + } + } + + c.Certificates = newCerts + access.Unlock() + } + + var issuedCertificate *xtls.Certificate + + // Create a new certificate from existing CA if possible + for _, rawCert := range ca { + if rawCert.Usage == Certificate_AUTHORITY_ISSUE { + newCert, err := issueCertificate(rawCert, domain) + if err != nil { + newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() + continue + } + + access.Lock() + c.Certificates = append(c.Certificates, *newCert) + issuedCertificate = &c.Certificates[len(c.Certificates)-1] + access.Unlock() + break + } + } + + if issuedCertificate == nil { + return nil, newError("failed to create a new certificate for ", domain) + } + + access.Lock() + c.BuildNameToCertificate() + access.Unlock() + + return issuedCertificate, nil + } +} + +func (c *Config) parseServerName() string { + return c.ServerName +} + +// GetXTLSConfig converts this Config into xtls.Config. +func (c *Config) GetXTLSConfig(opts ...Option) *xtls.Config { + root, err := c.getCertPool() + if err != nil { + newError("failed to load system root certificate").AtError().Base(err).WriteToLog() + } + + config := &xtls.Config{ + ClientSessionCache: globalSessionCache, + RootCAs: root, + InsecureSkipVerify: c.AllowInsecure, + NextProtos: c.NextProtocol, + SessionTicketsDisabled: c.DisableSessionResumption, + } + if c == nil { + return config + } + + for _, opt := range opts { + opt(config) + } + + config.Certificates = c.BuildCertificates() + config.BuildNameToCertificate() + + caCerts := c.getCustomCA() + if len(caCerts) > 0 { + config.GetCertificate = getGetCertificateFunc(config, caCerts) + } + + if sn := c.parseServerName(); len(sn) > 0 { + config.ServerName = sn + } + + if len(config.NextProtos) == 0 { + config.NextProtos = []string{"h2", "http/1.1"} + } + + return config +} + +// Option for building XTLS config. +type Option func(*xtls.Config) + +// WithDestination sets the server name in XTLS config. +func WithDestination(dest net.Destination) Option { + return func(config *xtls.Config) { + if dest.Address.Family().IsDomain() && config.ServerName == "" { + config.ServerName = dest.Address.Domain() + } + } +} + +// WithNextProto sets the ALPN values in XTLS config. +func WithNextProto(protocol ...string) Option { + return func(config *xtls.Config) { + if len(config.NextProtos) == 0 { + config.NextProtos = protocol + } + } +} + +// ConfigFromStreamSettings fetches Config from stream settings. Nil if not found. +func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config { + if settings == nil { + return nil + } + config, ok := settings.SecuritySettings.(*Config) + if !ok { + return nil + } + return config +} diff --git a/transport/internet/xtls/config.pb.go b/transport/internet/xtls/config.pb.go new file mode 100644 index 00000000..6ec1e71e --- /dev/null +++ b/transport/internet/xtls/config.pb.go @@ -0,0 +1,378 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: transport/internet/xtls/config.proto + +package xtls + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type Certificate_Usage int32 + +const ( + Certificate_ENCIPHERMENT Certificate_Usage = 0 + Certificate_AUTHORITY_VERIFY Certificate_Usage = 1 + Certificate_AUTHORITY_ISSUE Certificate_Usage = 2 +) + +// Enum value maps for Certificate_Usage. +var ( + Certificate_Usage_name = map[int32]string{ + 0: "ENCIPHERMENT", + 1: "AUTHORITY_VERIFY", + 2: "AUTHORITY_ISSUE", + } + Certificate_Usage_value = map[string]int32{ + "ENCIPHERMENT": 0, + "AUTHORITY_VERIFY": 1, + "AUTHORITY_ISSUE": 2, + } +) + +func (x Certificate_Usage) Enum() *Certificate_Usage { + p := new(Certificate_Usage) + *p = x + return p +} + +func (x Certificate_Usage) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Certificate_Usage) Descriptor() protoreflect.EnumDescriptor { + return file_transport_internet_xtls_config_proto_enumTypes[0].Descriptor() +} + +func (Certificate_Usage) Type() protoreflect.EnumType { + return &file_transport_internet_xtls_config_proto_enumTypes[0] +} + +func (x Certificate_Usage) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Certificate_Usage.Descriptor instead. +func (Certificate_Usage) EnumDescriptor() ([]byte, []int) { + return file_transport_internet_xtls_config_proto_rawDescGZIP(), []int{0, 0} +} + +type Certificate struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // XTLS certificate in x509 format. + Certificate []byte `protobuf:"bytes,1,opt,name=Certificate,proto3" json:"Certificate,omitempty"` + // XTLS key in x509 format. + Key []byte `protobuf:"bytes,2,opt,name=Key,proto3" json:"Key,omitempty"` + Usage Certificate_Usage `protobuf:"varint,3,opt,name=usage,proto3,enum=v2ray.core.transport.internet.xtls.Certificate_Usage" json:"usage,omitempty"` +} + +func (x *Certificate) Reset() { + *x = Certificate{} + if protoimpl.UnsafeEnabled { + mi := &file_transport_internet_xtls_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Certificate) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Certificate) ProtoMessage() {} + +func (x *Certificate) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_xtls_config_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Certificate.ProtoReflect.Descriptor instead. +func (*Certificate) Descriptor() ([]byte, []int) { + return file_transport_internet_xtls_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Certificate) GetCertificate() []byte { + if x != nil { + return x.Certificate + } + return nil +} + +func (x *Certificate) GetKey() []byte { + if x != nil { + return x.Key + } + return nil +} + +func (x *Certificate) GetUsage() Certificate_Usage { + if x != nil { + return x.Usage + } + return Certificate_ENCIPHERMENT +} + +type Config struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Whether or not to allow self-signed certificates. + AllowInsecure bool `protobuf:"varint,1,opt,name=allow_insecure,json=allowInsecure,proto3" json:"allow_insecure,omitempty"` + // Whether or not to allow insecure cipher suites. + AllowInsecureCiphers bool `protobuf:"varint,5,opt,name=allow_insecure_ciphers,json=allowInsecureCiphers,proto3" json:"allow_insecure_ciphers,omitempty"` + // List of certificates to be served on server. + Certificate []*Certificate `protobuf:"bytes,2,rep,name=certificate,proto3" json:"certificate,omitempty"` + // Override server name. + ServerName string `protobuf:"bytes,3,opt,name=server_name,json=serverName,proto3" json:"server_name,omitempty"` + // Lists of string as ALPN values. + NextProtocol []string `protobuf:"bytes,4,rep,name=next_protocol,json=nextProtocol,proto3" json:"next_protocol,omitempty"` + // Whether or not to disable session (ticket) resumption. + DisableSessionResumption bool `protobuf:"varint,6,opt,name=disable_session_resumption,json=disableSessionResumption,proto3" json:"disable_session_resumption,omitempty"` + // If true, root certificates on the system will not be loaded for verification. + DisableSystemRoot bool `protobuf:"varint,7,opt,name=disable_system_root,json=disableSystemRoot,proto3" json:"disable_system_root,omitempty"` +} + +func (x *Config) Reset() { + *x = Config{} + if protoimpl.UnsafeEnabled { + mi := &file_transport_internet_xtls_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_xtls_config_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_xtls_config_proto_rawDescGZIP(), []int{1} +} + +func (x *Config) GetAllowInsecure() bool { + if x != nil { + return x.AllowInsecure + } + return false +} + +func (x *Config) GetAllowInsecureCiphers() bool { + if x != nil { + return x.AllowInsecureCiphers + } + return false +} + +func (x *Config) GetCertificate() []*Certificate { + if x != nil { + return x.Certificate + } + return nil +} + +func (x *Config) GetServerName() string { + if x != nil { + return x.ServerName + } + return "" +} + +func (x *Config) GetNextProtocol() []string { + if x != nil { + return x.NextProtocol + } + return nil +} + +func (x *Config) GetDisableSessionResumption() bool { + if x != nil { + return x.DisableSessionResumption + } + return false +} + +func (x *Config) GetDisableSystemRoot() bool { + if x != nil { + return x.DisableSystemRoot + } + return false +} + +var File_transport_internet_xtls_config_proto protoreflect.FileDescriptor + +var file_transport_internet_xtls_config_proto_rawDesc = []byte{ + 0x0a, 0x24, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x6e, 0x65, 0x74, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x22, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, + 0x72, 0x65, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x78, 0x74, 0x6c, 0x73, 0x22, 0xd4, 0x01, 0x0a, 0x0b, 0x43, + 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x43, 0x65, + 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x0b, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x03, 0x4b, 0x65, 0x79, 0x12, 0x4b, + 0x0a, 0x05, 0x75, 0x73, 0x61, 0x67, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x35, 0x2e, + 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, + 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x78, 0x74, + 0x6c, 0x73, 0x2e, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x2e, 0x55, + 0x73, 0x61, 0x67, 0x65, 0x52, 0x05, 0x75, 0x73, 0x61, 0x67, 0x65, 0x22, 0x44, 0x0a, 0x05, 0x55, + 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x4e, 0x43, 0x49, 0x50, 0x48, 0x45, 0x52, + 0x4d, 0x45, 0x4e, 0x54, 0x10, 0x00, 0x12, 0x14, 0x0a, 0x10, 0x41, 0x55, 0x54, 0x48, 0x4f, 0x52, + 0x49, 0x54, 0x59, 0x5f, 0x56, 0x45, 0x52, 0x49, 0x46, 0x59, 0x10, 0x01, 0x12, 0x13, 0x0a, 0x0f, + 0x41, 0x55, 0x54, 0x48, 0x4f, 0x52, 0x49, 0x54, 0x59, 0x5f, 0x49, 0x53, 0x53, 0x55, 0x45, 0x10, + 0x02, 0x22, 0xec, 0x02, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x25, 0x0a, 0x0e, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x6e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x49, 0x6e, 0x73, 0x65, 0x63, + 0x75, 0x72, 0x65, 0x12, 0x34, 0x0a, 0x16, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x6e, 0x73, + 0x65, 0x63, 0x75, 0x72, 0x65, 0x5f, 0x63, 0x69, 0x70, 0x68, 0x65, 0x72, 0x73, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x14, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x49, 0x6e, 0x73, 0x65, 0x63, 0x75, + 0x72, 0x65, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x73, 0x12, 0x51, 0x0a, 0x0b, 0x63, 0x65, 0x72, + 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2f, + 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x78, + 0x74, 0x6c, 0x73, 0x2e, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, + 0x0b, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x0a, 0x0b, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x23, 0x0a, + 0x0d, 0x6e, 0x65, 0x78, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x6e, 0x65, 0x78, 0x74, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x3c, 0x0a, 0x1a, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x73, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x72, 0x65, 0x73, 0x75, 0x6d, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x18, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x75, 0x6d, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x2e, 0x0a, 0x13, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x73, 0x79, 0x73, 0x74, + 0x65, 0x6d, 0x5f, 0x72, 0x6f, 0x6f, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x64, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x52, 0x6f, 0x6f, 0x74, + 0x42, 0x77, 0x0a, 0x26, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, + 0x72, 0x65, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x78, 0x74, 0x6c, 0x73, 0x50, 0x01, 0x5a, 0x26, 0x76, 0x32, + 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2f, + 0x78, 0x74, 0x6c, 0x73, 0xaa, 0x02, 0x22, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e, 0x43, 0x6f, 0x72, + 0x65, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49, 0x6e, 0x74, 0x65, + 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x58, 0x74, 0x6c, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, +} + +var ( + file_transport_internet_xtls_config_proto_rawDescOnce sync.Once + file_transport_internet_xtls_config_proto_rawDescData = file_transport_internet_xtls_config_proto_rawDesc +) + +func file_transport_internet_xtls_config_proto_rawDescGZIP() []byte { + file_transport_internet_xtls_config_proto_rawDescOnce.Do(func() { + file_transport_internet_xtls_config_proto_rawDescData = protoimpl.X.CompressGZIP(file_transport_internet_xtls_config_proto_rawDescData) + }) + return file_transport_internet_xtls_config_proto_rawDescData +} + +var file_transport_internet_xtls_config_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_transport_internet_xtls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_transport_internet_xtls_config_proto_goTypes = []interface{}{ + (Certificate_Usage)(0), // 0: v2ray.core.transport.internet.xtls.Certificate.Usage + (*Certificate)(nil), // 1: v2ray.core.transport.internet.xtls.Certificate + (*Config)(nil), // 2: v2ray.core.transport.internet.xtls.Config +} +var file_transport_internet_xtls_config_proto_depIdxs = []int32{ + 0, // 0: v2ray.core.transport.internet.xtls.Certificate.usage:type_name -> v2ray.core.transport.internet.xtls.Certificate.Usage + 1, // 1: v2ray.core.transport.internet.xtls.Config.certificate:type_name -> v2ray.core.transport.internet.xtls.Certificate + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_transport_internet_xtls_config_proto_init() } +func file_transport_internet_xtls_config_proto_init() { + if File_transport_internet_xtls_config_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_transport_internet_xtls_config_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Certificate); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_transport_internet_xtls_config_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Config); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_transport_internet_xtls_config_proto_rawDesc, + NumEnums: 1, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_xtls_config_proto_goTypes, + DependencyIndexes: file_transport_internet_xtls_config_proto_depIdxs, + EnumInfos: file_transport_internet_xtls_config_proto_enumTypes, + MessageInfos: file_transport_internet_xtls_config_proto_msgTypes, + }.Build() + File_transport_internet_xtls_config_proto = out.File + file_transport_internet_xtls_config_proto_rawDesc = nil + file_transport_internet_xtls_config_proto_goTypes = nil + file_transport_internet_xtls_config_proto_depIdxs = nil +} diff --git a/transport/internet/xtls/config.proto b/transport/internet/xtls/config.proto new file mode 100644 index 00000000..9ad77221 --- /dev/null +++ b/transport/internet/xtls/config.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package v2ray.core.transport.internet.xtls; +option csharp_namespace = "V2Ray.Core.Transport.Internet.Xtls"; +option go_package = "v2ray.com/core/transport/internet/xtls"; +option java_package = "com.v2ray.core.transport.internet.xtls"; +option java_multiple_files = true; + +message Certificate { + // XTLS certificate in x509 format. + bytes Certificate = 1; + + // XTLS key in x509 format. + bytes Key = 2; + + enum Usage { + ENCIPHERMENT = 0; + AUTHORITY_VERIFY = 1; + AUTHORITY_ISSUE = 2; + } + + Usage usage = 3; +} + +message Config { + // Whether or not to allow self-signed certificates. + bool allow_insecure = 1; + + // Whether or not to allow insecure cipher suites. + bool allow_insecure_ciphers = 5; + + // List of certificates to be served on server. + repeated Certificate certificate = 2; + + // Override server name. + string server_name = 3; + + // Lists of string as ALPN values. + repeated string next_protocol = 4; + + // Whether or not to disable session (ticket) resumption. + bool disable_session_resumption = 6; + + // If true, root certificates on the system will not be loaded for verification. + bool disable_system_root = 7; +} diff --git a/transport/internet/xtls/config_other.go b/transport/internet/xtls/config_other.go new file mode 100644 index 00000000..a1dda046 --- /dev/null +++ b/transport/internet/xtls/config_other.go @@ -0,0 +1,53 @@ +// +build !windows +// +build !confonly + +package xtls + +import ( + "crypto/x509" + "sync" +) + +type rootCertsCache struct { + sync.Mutex + pool *x509.CertPool +} + +func (c *rootCertsCache) load() (*x509.CertPool, error) { + c.Lock() + defer c.Unlock() + + if c.pool != nil { + return c.pool, nil + } + + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + c.pool = pool + return pool, nil +} + +var rootCerts rootCertsCache + +func (c *Config) getCertPool() (*x509.CertPool, error) { + if c.DisableSystemRoot { + return c.loadSelfCertPool() + } + + if len(c.Certificate) == 0 { + return rootCerts.load() + } + + pool, err := x509.SystemCertPool() + if err != nil { + return nil, newError("system root").AtWarning().Base(err) + } + for _, cert := range c.Certificate { + if !pool.AppendCertsFromPEM(cert.Certificate) { + return nil, newError("append cert to root").AtWarning().Base(err) + } + } + return pool, err +} diff --git a/transport/internet/xtls/config_test.go b/transport/internet/xtls/config_test.go new file mode 100644 index 00000000..9e7227c9 --- /dev/null +++ b/transport/internet/xtls/config_test.go @@ -0,0 +1,100 @@ +package xtls_test + +import ( + "crypto/x509" + "testing" + "time" + + xtls "github.com/xtls/go" + + "v2ray.com/core/common" + "v2ray.com/core/common/protocol/tls/cert" + . "v2ray.com/core/transport/internet/xtls" +) + +func TestCertificateIssuing(t *testing.T) { + certificate := ParseCertificate(cert.MustGenerate(nil, cert.Authority(true), cert.KeyUsage(x509.KeyUsageCertSign))) + certificate.Usage = Certificate_AUTHORITY_ISSUE + + c := &Config{ + Certificate: []*Certificate{ + certificate, + }, + } + + xtlsConfig := c.GetXTLSConfig() + v2rayCert, err := xtlsConfig.GetCertificate(&xtls.ClientHelloInfo{ + ServerName: "www.v2fly.org", + }) + common.Must(err) + + x509Cert, err := x509.ParseCertificate(v2rayCert.Certificate[0]) + common.Must(err) + if !x509Cert.NotAfter.After(time.Now()) { + t.Error("NotAfter: ", x509Cert.NotAfter) + } +} + +func TestExpiredCertificate(t *testing.T) { + caCert := cert.MustGenerate(nil, cert.Authority(true), cert.KeyUsage(x509.KeyUsageCertSign)) + expiredCert := cert.MustGenerate(caCert, cert.NotAfter(time.Now().Add(time.Minute*-2)), cert.CommonName("www.v2fly.org"), cert.DNSNames("www.v2fly.org")) + + certificate := ParseCertificate(caCert) + certificate.Usage = Certificate_AUTHORITY_ISSUE + + certificate2 := ParseCertificate(expiredCert) + + c := &Config{ + Certificate: []*Certificate{ + certificate, + certificate2, + }, + } + + xtlsConfig := c.GetXTLSConfig() + v2rayCert, err := xtlsConfig.GetCertificate(&xtls.ClientHelloInfo{ + ServerName: "www.v2fly.org", + }) + common.Must(err) + + x509Cert, err := x509.ParseCertificate(v2rayCert.Certificate[0]) + common.Must(err) + if !x509Cert.NotAfter.After(time.Now()) { + t.Error("NotAfter: ", x509Cert.NotAfter) + } +} + +func TestInsecureCertificates(t *testing.T) { + c := &Config{ + AllowInsecureCiphers: true, + } + + xtlsConfig := c.GetXTLSConfig() + if len(xtlsConfig.CipherSuites) > 0 { + t.Fatal("Unexpected tls cipher suites list: ", xtlsConfig.CipherSuites) + } +} + +func BenchmarkCertificateIssuing(b *testing.B) { + certificate := ParseCertificate(cert.MustGenerate(nil, cert.Authority(true), cert.KeyUsage(x509.KeyUsageCertSign))) + certificate.Usage = Certificate_AUTHORITY_ISSUE + + c := &Config{ + Certificate: []*Certificate{ + certificate, + }, + } + + xtlsConfig := c.GetXTLSConfig() + lenCerts := len(xtlsConfig.Certificates) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = xtlsConfig.GetCertificate(&xtls.ClientHelloInfo{ + ServerName: "www.v2fly.org", + }) + delete(xtlsConfig.NameToCertificate, "www.v2fly.org") + xtlsConfig.Certificates = xtlsConfig.Certificates[:lenCerts] + } +} diff --git a/transport/internet/xtls/config_windows.go b/transport/internet/xtls/config_windows.go new file mode 100644 index 00000000..8c5bf01d --- /dev/null +++ b/transport/internet/xtls/config_windows.go @@ -0,0 +1,14 @@ +// +build windows +// +build !confonly + +package xtls + +import "crypto/x509" + +func (c *Config) getCertPool() (*x509.CertPool, error) { + if c.DisableSystemRoot { + return c.loadSelfCertPool() + } + + return nil, nil +} diff --git a/transport/internet/xtls/errors.generated.go b/transport/internet/xtls/errors.generated.go new file mode 100644 index 00000000..9269f558 --- /dev/null +++ b/transport/internet/xtls/errors.generated.go @@ -0,0 +1,9 @@ +package xtls + +import "v2ray.com/core/common/errors" + +type errPathObjHolder struct{} + +func newError(values ...interface{}) *errors.Error { + return errors.New(values...).WithPathObj(errPathObjHolder{}) +} diff --git a/transport/internet/xtls/xtls.go b/transport/internet/xtls/xtls.go new file mode 100644 index 00000000..e34408bf --- /dev/null +++ b/transport/internet/xtls/xtls.go @@ -0,0 +1,50 @@ +// +build !confonly + +package xtls + +import ( + xtls "github.com/xtls/go" + + "v2ray.com/core/common/buf" + "v2ray.com/core/common/net" +) + +//go:generate errorgen + +var ( + _ buf.Writer = (*Conn)(nil) +) + +type Conn struct { + *xtls.Conn +} + +func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error { + mb = buf.Compact(mb) + mb, err := buf.WriteMultiBuffer(c, mb) + buf.ReleaseMulti(mb) + return err +} + +func (c *Conn) HandshakeAddress() net.Address { + if err := c.Handshake(); err != nil { + return nil + } + state := c.ConnectionState() + if state.ServerName == "" { + return nil + } + return net.ParseAddress(state.ServerName) +} + +// Client initiates a XTLS client handshake on the given connection. +func Client(c net.Conn, config *xtls.Config) net.Conn { + xtlsConn := xtls.Client(c, config) + return &Conn{Conn: xtlsConn} +} + +// Server initiates a XTLS server handshake on the given connection. +func Server(c net.Conn, config *xtls.Config) net.Conn { + xtlsConn := xtls.Server(c, config) + return &Conn{Conn: xtlsConn} +}