diff --git a/common/protocol/server_picker.go b/common/protocol/server_picker.go deleted file mode 100644 index 62aa404e..00000000 --- a/common/protocol/server_picker.go +++ /dev/null @@ -1,89 +0,0 @@ -package protocol - -import ( - "sync" -) - -type ServerList struct { - sync.RWMutex - servers []*ServerSpec -} - -func NewServerList() *ServerList { - return &ServerList{} -} - -func (sl *ServerList) AddServer(server *ServerSpec) { - sl.Lock() - defer sl.Unlock() - - sl.servers = append(sl.servers, server) -} - -func (sl *ServerList) Size() uint32 { - sl.RLock() - defer sl.RUnlock() - - return uint32(len(sl.servers)) -} - -func (sl *ServerList) GetServer(idx uint32) *ServerSpec { - sl.Lock() - defer sl.Unlock() - - for { - if idx >= uint32(len(sl.servers)) { - return nil - } - - server := sl.servers[idx] - if !server.IsValid() { - sl.removeServer(idx) - continue - } - - return server - } -} - -func (sl *ServerList) removeServer(idx uint32) { - n := len(sl.servers) - sl.servers[idx] = sl.servers[n-1] - sl.servers = sl.servers[:n-1] -} - -type ServerPicker interface { - PickServer() *ServerSpec -} - -type RoundRobinServerPicker struct { - sync.Mutex - serverlist *ServerList - nextIndex uint32 -} - -func NewRoundRobinServerPicker(serverlist *ServerList) *RoundRobinServerPicker { - return &RoundRobinServerPicker{ - serverlist: serverlist, - nextIndex: 0, - } -} - -func (p *RoundRobinServerPicker) PickServer() *ServerSpec { - p.Lock() - defer p.Unlock() - - next := p.nextIndex - server := p.serverlist.GetServer(next) - if server == nil { - next = 0 - server = p.serverlist.GetServer(0) - } - next++ - if next >= p.serverlist.Size() { - next = 0 - } - p.nextIndex = next - - return server -} diff --git a/common/protocol/server_picker_test.go b/common/protocol/server_picker_test.go deleted file mode 100644 index 8919b10b..00000000 --- a/common/protocol/server_picker_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package protocol_test - -import ( - "testing" - "time" - - "github.com/xtls/xray-core/common/net" - . "github.com/xtls/xray-core/common/protocol" -) - -func TestServerList(t *testing.T) { - list := NewServerList() - list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid())) - if list.Size() != 1 { - t.Error("list size: ", list.Size()) - } - list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second)))) - if list.Size() != 2 { - t.Error("list.size: ", list.Size()) - } - - server := list.GetServer(1) - if server.Destination().Port != 2 { - t.Error("server: ", server.Destination()) - } - time.Sleep(2 * time.Second) - server = list.GetServer(1) - if server != nil { - t.Error("server: ", server) - } - - server = list.GetServer(0) - if server.Destination().Port != 1 { - t.Error("server: ", server.Destination()) - } -} - -func TestServerPicker(t *testing.T) { - list := NewServerList() - list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid())) - list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second)))) - list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(3)), BeforeTime(time.Now().Add(time.Second)))) - - picker := NewRoundRobinServerPicker(list) - server := picker.PickServer() - if server.Destination().Port != 1 { - t.Error("server: ", server.Destination()) - } - server = picker.PickServer() - if server.Destination().Port != 2 { - t.Error("server: ", server.Destination()) - } - server = picker.PickServer() - if server.Destination().Port != 3 { - t.Error("server: ", server.Destination()) - } - server = picker.PickServer() - if server.Destination().Port != 1 { - t.Error("server: ", server.Destination()) - } - - time.Sleep(2 * time.Second) - server = picker.PickServer() - if server.Destination().Port != 1 { - t.Error("server: ", server.Destination()) - } - server = picker.PickServer() - if server.Destination().Port != 1 { - t.Error("server: ", server.Destination()) - } -} diff --git a/common/protocol/server_spec.go b/common/protocol/server_spec.go index 24b778bb..63349af1 100644 --- a/common/protocol/server_spec.go +++ b/common/protocol/server_spec.go @@ -4,7 +4,6 @@ import ( "sync" "time" - "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/net" ) @@ -98,6 +97,8 @@ func (s *ServerSpec) AddUser(user *MemoryUser) { s.users = append(s.users, user) } +// Locking it only using the first user when user(s) exists. +// Should change after refactor func (s *ServerSpec) PickUser() *MemoryUser { s.RLock() defer s.RUnlock() @@ -106,10 +107,8 @@ func (s *ServerSpec) PickUser() *MemoryUser { switch userCount { case 0: return nil - case 1: - return s.users[0] default: - return s.users[dice.Roll(userCount)] + return s.users[0] } } diff --git a/proxy/http/client.go b/proxy/http/client.go index b1326bec..378fa9f1 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -31,7 +31,7 @@ import ( ) type Client struct { - serverPicker protocol.ServerPicker + server *protocol.ServerSpec policyManager policy.Manager header []*Header } @@ -48,21 +48,19 @@ var ( // NewClient create a new http client based on the given config. func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { - serverList := protocol.NewServerList() - for _, rec := range config.Server { - s, err := protocol.NewServerSpecFromPB(rec) - if err != nil { - return nil, errors.New("failed to get server spec").Base(err) - } - serverList.AddServer(s) + if len(config.Server) != 1 { + return nil, errors.New(`only one target server allowed`) } - if serverList.Size() == 0 { - return nil, errors.New("0 target server") + // Harcoded [0] for processing compatibility. + // Should change after refactor. + server, err := protocol.NewServerSpecFromPB(config.Server[0]) + if err != nil { + return nil, errors.New("failed to get server spec").Base(err) } v := core.MustFromContext(ctx) return &Client{ - serverPicker: protocol.NewRoundRobinServerPicker(serverList), + server: server, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), header: config.Header, }, nil @@ -102,9 +100,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } if err := retry.ExponentialBackoff(5, 100).On(func() error { - server := c.serverPicker.PickServer() - dest := server.Destination() - user = server.PickUser() + dest := c.server.Destination() + user = c.server.PickUser() netConn, err := setUpHTTPTunnel(ctx, dest, targetAddr, user, dialer, header, firstPayload) if netConn != nil { diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index c2ad1c1c..f7cf0441 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -22,27 +22,25 @@ import ( // Client is a inbound handler for Shadowsocks protocol type Client struct { - serverPicker protocol.ServerPicker + server *protocol.ServerSpec policyManager policy.Manager } // NewClient create a new Shadowsocks client. func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { - serverList := protocol.NewServerList() - for _, rec := range config.Server { - s, err := protocol.NewServerSpecFromPB(rec) - if err != nil { - return nil, errors.New("failed to parse server spec").Base(err) - } - serverList.AddServer(s) + if len(config.Server) != 1 { + return nil, errors.New(`only one target server allowed`) } - if serverList.Size() == 0 { - return nil, errors.New("0 server") + // Harcoded [0] for processing compatibility. + // Should change after refactor. + server, err := protocol.NewServerSpecFromPB(config.Server[0]) + if err != nil { + return nil, errors.New("failed to get server spec").Base(err) } v := core.MustFromContext(ctx) client := &Client{ - serverPicker: protocol.NewRoundRobinServerPicker(serverList), + server: server, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), } return client, nil @@ -60,12 +58,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter destination := ob.Target network := destination.Network - var server *protocol.ServerSpec var conn stat.Connection err := retry.ExponentialBackoff(5, 100).On(func() error { - server = c.serverPicker.PickServer() - dest := server.Destination() + dest := c.server.Destination() dest.Network = network rawConn, err := dialer.Dial(ctx, dest) if err != nil { @@ -78,7 +74,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if err != nil { return errors.New("failed to find an available destination").AtWarning().Base(err) } - errors.LogInfo(ctx, "tunneling request to ", destination, " via ", network, ":", server.Destination().NetAddr()) + errors.LogInfo(ctx, "tunneling request to ", destination, " via ", network, ":", c.server.Destination().NetAddr()) defer conn.Close() @@ -93,7 +89,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter request.Command = protocol.RequestCommandUDP } - user := server.PickUser() + user := c.server.PickUser() _, ok := user.Account.(*MemoryAccount) if !ok { return errors.New("user account is not valid") diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 232215ec..b0bf1f88 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -22,27 +22,25 @@ import ( // Client is a Socks5 client. type Client struct { - serverPicker protocol.ServerPicker + server *protocol.ServerSpec policyManager policy.Manager } // NewClient create a new Socks5 client based on the given config. func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { - serverList := protocol.NewServerList() - for _, rec := range config.Server { - s, err := protocol.NewServerSpecFromPB(rec) - if err != nil { - return nil, errors.New("failed to get server spec").Base(err) - } - serverList.AddServer(s) + if len(config.Server) != 1 { + return nil, errors.New(`only one target server allowed`) } - if serverList.Size() == 0 { - return nil, errors.New("0 target server") + // Harcoded [0] for processing compatibility. + // Should change after refactor. + server, err := protocol.NewServerSpecFromPB(config.Server[0]) + if err != nil { + return nil, errors.New("failed to get server spec").Base(err) } v := core.MustFromContext(ctx) c := &Client{ - serverPicker: protocol.NewRoundRobinServerPicker(serverList), + server: server, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), } @@ -61,16 +59,13 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter // Destination of the inner request. destination := ob.Target - // Outbound server. - var server *protocol.ServerSpec // Outbound server's destination. var dest net.Destination // Connection to the outbound server. var conn stat.Connection if err := retry.ExponentialBackoff(5, 100).On(func() error { - server = c.serverPicker.PickServer() - dest = server.Destination() + dest = c.server.Destination() rawConn, err := dialer.Dial(ctx, dest) if err != nil { return err @@ -101,7 +96,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter request.Command = protocol.RequestCommandUDP } - user := server.PickUser() + user := c.server.PickUser() if user != nil { request.User = user p = c.policyManager.ForLevel(user.Level) diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index a667d352..5906e437 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -22,27 +22,25 @@ import ( // Client is a inbound handler for trojan protocol type Client struct { - serverPicker protocol.ServerPicker + server *protocol.ServerSpec policyManager policy.Manager } // NewClient create a new trojan client. func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { - serverList := protocol.NewServerList() - for _, rec := range config.Server { - s, err := protocol.NewServerSpecFromPB(rec) - if err != nil { - return nil, errors.New("failed to parse server spec").Base(err) - } - serverList.AddServer(s) + if len(config.Server) != 1 { + return nil, errors.New(`only one target server allowed`) } - if serverList.Size() == 0 { - return nil, errors.New("0 server") + // Harcoded [0] for processing compatibility. + // Should change after refactor. + server, err := protocol.NewServerSpecFromPB(config.Server[0]) + if err != nil { + return nil, errors.New("failed to get server spec").Base(err) } v := core.MustFromContext(ctx) client := &Client{ - serverPicker: protocol.NewRoundRobinServerPicker(serverList), + server: server, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), } return client, nil @@ -60,12 +58,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter destination := ob.Target network := destination.Network - var server *protocol.ServerSpec var conn stat.Connection err := retry.ExponentialBackoff(5, 100).On(func() error { - server = c.serverPicker.PickServer() - rawConn, err := dialer.Dial(ctx, server.Destination()) + rawConn, err := dialer.Dial(ctx, c.server.Destination()) if err != nil { return err } @@ -76,11 +72,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if err != nil { return errors.New("failed to find an available destination").AtWarning().Base(err) } - errors.LogInfo(ctx, "tunneling request to ", destination, " via ", server.Destination().NetAddr()) + errors.LogInfo(ctx, "tunneling request to ", destination, " via ", c.server.Destination().NetAddr()) defer conn.Close() - user := server.PickUser() + user := c.server.PickUser() account, ok := user.Account.(*MemoryAccount) if !ok { return errors.New("user account is not valid") diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 30b9dcf9..6ee16982 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -47,8 +47,7 @@ func init() { // Handler is an outbound connection handler for VLess protocol. type Handler struct { - serverList *protocol.ServerList - serverPicker protocol.ServerPicker + server *protocol.ServerSpec policyManager policy.Manager cone bool encryption *encryption.ClientInstance @@ -57,24 +56,24 @@ type Handler struct { // New creates a new VLess outbound handler. func New(ctx context.Context, config *Config) (*Handler, error) { - serverList := protocol.NewServerList() - for _, rec := range config.Vnext { - s, err := protocol.NewServerSpecFromPB(rec) - if err != nil { - return nil, errors.New("failed to parse server spec").Base(err).AtError() - } - serverList.AddServer(s) + if len(config.Vnext) != 1 { + return nil, errors.New(`only one vnext allowed`) + } + // Harcoded [0] for processing compatibility. + // Should change after refactor. + server, err := protocol.NewServerSpecFromPB(config.Vnext[0]) + if err != nil { + return nil, errors.New("failed to get server spec").Base(err).AtError() } v := core.MustFromContext(ctx) handler := &Handler{ - serverList: serverList, - serverPicker: protocol.NewRoundRobinServerPicker(serverList), + server: server, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), cone: ctx.Value("cone").(bool), } - a := handler.serverPicker.PickServer().PickUser().Account.(*vless.MemoryAccount) + a := handler.server.PickUser().Account.(*vless.MemoryAccount) if a.Encryption != "" && a.Encryption != "none" { s := strings.Split(a.Encryption, ".") var nfsPKeysBytes [][]byte @@ -125,12 +124,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } ob.Name = "vless" - var rec *protocol.ServerSpec var conn stat.Connection if err := retry.ExponentialBackoff(5, 200).On(func() error { - rec = h.serverPicker.PickServer() var err error - conn, err = dialer.Dial(ctx, rec.Destination()) + conn, err = dialer.Dial(ctx, h.server.Destination()) if err != nil { return err } @@ -145,7 +142,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte iConn = statConn.Connection } target := ob.Target - errors.LogInfo(ctx, "tunneling request to ", target, " via ", rec.Destination().NetAddr()) + errors.LogInfo(ctx, "tunneling request to ", target, " via ", h.server.Destination().NetAddr()) if h.encryption != nil { var err error @@ -172,7 +169,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte request := &protocol.RequestHeader{ Version: encoding.Version, - User: rec.PickUser(), + User: h.server.PickUser(), Command: command, Address: target.Address, Port: target.Port, diff --git a/proxy/vmess/outbound/command.go b/proxy/vmess/outbound/command.go deleted file mode 100644 index 2d4747dc..00000000 --- a/proxy/vmess/outbound/command.go +++ /dev/null @@ -1,41 +0,0 @@ -package outbound - -import ( - "time" - - "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/protocol" - "github.com/xtls/xray-core/proxy/vmess" -) - -func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { - rawAccount := &vmess.Account{ - Id: cmd.ID.String(), - SecuritySettings: &protocol.SecurityConfig{ - Type: protocol.SecurityType_AUTO, - }, - } - - account, err := rawAccount.AsAccount() - common.Must(err) - user := &protocol.MemoryUser{ - Email: "", - Level: cmd.Level, - Account: account, - } - dest := net.TCPDestination(cmd.Host, cmd.Port) - until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute) - h.serverList.AddServer(protocol.NewServerSpec(dest, protocol.BeforeTime(until), user)) -} - -func (h *Handler) handleCommand(dest net.Destination, cmd protocol.ResponseCommand) { - switch typedCommand := cmd.(type) { - case *protocol.CommandSwitchAccount: - if typedCommand.Host == nil { - typedCommand.Host = dest.Address - } - h.handleSwitchAccount(typedCommand) - default: - } -} diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 27079e03..5d8f4625 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -29,27 +29,26 @@ import ( // Handler is an outbound connection handler for VMess protocol. type Handler struct { - serverList *protocol.ServerList - serverPicker protocol.ServerPicker + server *protocol.ServerSpec policyManager policy.Manager cone bool } // New creates a new VMess outbound handler. func New(ctx context.Context, config *Config) (*Handler, error) { - serverList := protocol.NewServerList() - for _, rec := range config.Receiver { - s, err := protocol.NewServerSpecFromPB(rec) - if err != nil { - return nil, errors.New("failed to parse server spec").Base(err) - } - serverList.AddServer(s) + if len(config.Receiver) != 1 { + return nil, errors.New(`only one vnext allowed`) + } + // Harcoded [0] for processing compatibility. + // Should change after refactor. + server, err := protocol.NewServerSpecFromPB(config.Receiver[0]) + if err != nil { + return nil, errors.New("failed to get server spec").Base(err) } v := core.MustFromContext(ctx) handler := &Handler{ - serverList: serverList, - serverPicker: protocol.NewRoundRobinServerPicker(serverList), + server: server, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), cone: ctx.Value("cone").(bool), } @@ -67,11 +66,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte ob.Name = "vmess" ob.CanSpliceCopy = 3 - var rec *protocol.ServerSpec var conn stat.Connection err := retry.ExponentialBackoff(5, 200).On(func() error { - rec = h.serverPicker.PickServer() - rawConn, err := dialer.Dial(ctx, rec.Destination()) + rawConn, err := dialer.Dial(ctx, h.server.Destination()) if err != nil { return err } @@ -85,7 +82,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte defer conn.Close() target := ob.Target - errors.LogInfo(ctx, "tunneling request to ", target, " via ", rec.Destination().NetAddr()) + errors.LogInfo(ctx, "tunneling request to ", target, " via ", h.server.Destination().NetAddr()) command := protocol.RequestCommandTCP if target.Network == net.Network_UDP { @@ -95,7 +92,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte command = protocol.RequestCommandMux } - user := rec.PickUser() + user := h.server.PickUser() request := &protocol.RequestHeader{ Version: encoding.Version, User: user, @@ -198,11 +195,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) reader := &buf.BufferedReader{Reader: buf.NewReader(conn)} - header, err := session.DecodeResponseHeader(reader) - if err != nil { - return errors.New("failed to read header").Base(err) - } - h.handleCommand(rec.Destination(), header.Command) bodyReader, err := session.DecodeResponseBody(request, reader) if err != nil {