diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 74e55ab4..e3524350 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -196,6 +196,47 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran return inboundLink, outboundLink } +func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) *transport.Link { + sessionInbound := session.InboundFromContext(ctx) + var user *protocol.MemoryUser + if sessionInbound != nil { + user = sessionInbound.User + } + + link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader} + + if user != nil && len(user.Email) > 0 { + p := d.policy.ForLevel(user.Level) + if p.Stats.UserUplink { + name := "user>>>" + user.Email + ">>>traffic>>>uplink" + if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { + link.Reader.(*buf.TimeoutWrapperReader).Counter = c + } + } + if p.Stats.UserDownlink { + name := "user>>>" + user.Email + ">>>traffic>>>downlink" + if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { + link.Writer = &SizeStatWriter{ + Counter: c, + Writer: link.Writer, + } + } + } + if p.Stats.UserOnline { + name := "user>>>" + user.Email + ">>>online" + if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil { + sessionInbounds := session.InboundFromContext(ctx) + userIP := sessionInbounds.Source.Address.String() + om.AddIP(userIP) + // log Online user with ips + // errors.LogDebug(ctx, "user>>>" + user.Email + ">>>online", om.Count(), om.List()) + } + } + } + + return link +} + func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { domain := result.Domain() if domain == "" { @@ -316,6 +357,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De content = new(session.Content) ctx = session.ContextWithContent(ctx, content) } + outbound = d.WrapLink(ctx, outbound) sniffingRequest := content.SniffingRequest if !sniffingRequest.Enabled { d.routedDispatch(ctx, outbound, destination) diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index 3e46cc6c..5cc60ad7 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/mux" "github.com/xtls/xray-core/common/net" @@ -200,6 +201,7 @@ func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, l return w.dispatcher.DispatchLink(ctx, dest, link) } + link = w.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) w.handleInternalConn(link) return nil diff --git a/common/buf/io.go b/common/buf/io.go index e1de461d..75565e53 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -30,6 +30,7 @@ type TimeoutReader interface { type TimeoutWrapperReader struct { Reader + stats.Counter mb MultiBuffer err error done chan struct{} @@ -39,11 +40,16 @@ func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) { if r.done != nil { <-r.done r.done = nil + if r.Counter != nil { + r.Counter.Add(int64(r.mb.Len())) + } return r.mb, r.err } - r.mb = nil - r.err = nil - return r.Reader.ReadMultiBuffer() + r.mb, r.err = r.Reader.ReadMultiBuffer() + if r.Counter != nil { + r.Counter.Add(int64(r.mb.Len())) + } + return r.mb, r.err } func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) { @@ -62,6 +68,9 @@ func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (M select { case <-r.done: r.done = nil + if r.Counter != nil { + r.Counter.Add(int64(r.mb.Len())) + } return r.mb, r.err case <-timeout: return nil, nil diff --git a/common/mux/server.go b/common/mux/server.go index 12e4a68f..ac121a9f 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -4,6 +4,7 @@ import ( "context" "io" + "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -61,6 +62,7 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t if dest.Address != muxCoolAddress { return s.dispatcher.DispatchLink(ctx, dest, link) } + link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) _, err := NewServerWorker(ctx, s.dispatcher, link) return err } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index c14fe34d..90fc53e8 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -182,7 +182,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st } if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{ - Reader: &buf.TimeoutWrapperReader{Reader: reader}, + Reader: reader, Writer: writer}, ); err != nil { return errors.New("failed to dispatch request").Base(err) diff --git a/proxy/http/server.go b/proxy/http/server.go index 44f5a102..90a07b38 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -193,7 +193,7 @@ func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *buf inbound.CanSpliceCopy = 1 } if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{ - Reader: &buf.TimeoutWrapperReader{Reader: reader}, + Reader: reader, Writer: buf.NewWriter(conn)}, ); err != nil { return errors.New("failed to dispatch request").Base(err) diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 455e081f..478410f3 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -161,7 +161,7 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche inbound.CanSpliceCopy = 1 } if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{ - Reader: &buf.TimeoutWrapperReader{Reader: reader}, + Reader: reader, Writer: buf.NewWriter(conn)}, ); err != nil { return errors.New("failed to dispatch request").Base(err) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 10646b97..b1b0a916 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -563,7 +563,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s bufferWriter.SetFlushNext() if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{ - Reader: &buf.TimeoutWrapperReader{Reader: clientReader}, + Reader: clientReader, Writer: clientWriter}, ); err != nil { return errors.New("failed to dispatch request").Base(err)