diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 50f39741..3987ee53 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -54,7 +54,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin } ctx = proxy.ContextWithTarget(ctx, destination) - outbound := ray.NewRay(ctx) + outbound := ray.New(ctx) snifferList := proxyman.ProtocolSniffersFromContext(ctx) if destination.Address.Family().IsDomain() || len(snifferList) == 0 { go d.routedDispatch(ctx, outbound, destination) diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index c46312bc..a94c416d 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -87,7 +87,7 @@ var muxCoolPort = net.Port(9527) func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client, error) { ctx := proxy.ContextWithTarget(context.Background(), net.TCPDestination(muxCoolAddress, muxCoolPort)) ctx, cancel := context.WithCancel(ctx) - pipe := ray.NewRay(ctx) + pipe := ray.New(ctx) c := &Client{ sessionManager: NewSessionManager(), @@ -266,7 +266,7 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (ray.Inboun return s.dispatcher.Dispatch(ctx, dest) } - ray := ray.NewRay(ctx) + ray := ray.New(ctx) worker := &ServerWorker{ dispatcher: s.dispatcher, outboundRay: ray, diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index a7051307..3bc57ea0 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -103,7 +103,7 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn if handler != nil { newError("proxying to ", tag, " for dest ", dest).AtDebug().WithContext(ctx).WriteToLog() ctx = proxy.ContextWithTarget(ctx, dest) - stream := ray.NewRay(ctx) + stream := ray.New(ctx) go handler.Dispatch(ctx, stream) return ray.NewConnection(stream.InboundOutput(), stream.InboundInput()), nil } diff --git a/app/stats/stats.go b/app/stats/stats.go index 9177e2a2..62298238 100644 --- a/app/stats/stats.go +++ b/app/stats/stats.go @@ -18,7 +18,7 @@ func (c *Counter) Value() int64 { return atomic.LoadInt64(&c.value) } -func (c *Counter) Exchange(newValue int64) int64 { +func (c *Counter) Set(newValue int64) int64 { return atomic.SwapInt64(&c.value, newValue) } diff --git a/app/stats/stats_test.go b/app/stats/stats_test.go index 9a3a1e5c..51a145ac 100644 --- a/app/stats/stats_test.go +++ b/app/stats/stats_test.go @@ -1,10 +1,12 @@ package stats_test import ( + "context" "testing" "v2ray.com/core" . "v2ray.com/core/app/stats" + "v2ray.com/core/common" . "v2ray.com/ext/assert" ) @@ -13,3 +15,18 @@ func TestInternface(t *testing.T) { assert((*Manager)(nil), Implements, (*core.StatManager)(nil)) } + +func TestStatsCounter(t *testing.T) { + assert := With(t) + + raw, err := common.CreateObject(context.Background(), &Config{}) + assert(err, IsNil) + + m := raw.(core.StatManager) + c, err := m.RegisterCounter("test.counter") + assert(err, IsNil) + + assert(c.Add(1), Equals, int64(1)) + assert(c.Set(0), Equals, int64(1)) + assert(c.Value(), Equals, int64(0)) +} diff --git a/stats.go b/stats.go index 42eb5e74..0505c0b1 100644 --- a/stats.go +++ b/stats.go @@ -6,7 +6,7 @@ import ( type StatCounter interface { Value() int64 - Exchange(int64) int64 + Set(int64) int64 Add(int64) int64 } diff --git a/transport/internet/udp/dispatcher_test.go b/transport/internet/udp/dispatcher_test.go index e6c1958a..22f4a427 100644 --- a/transport/internet/udp/dispatcher_test.go +++ b/transport/internet/udp/dispatcher_test.go @@ -33,7 +33,7 @@ func TestSameDestinationDispatching(t *testing.T) { assert := With(t) ctx, cancel := context.WithCancel(context.Background()) - link := ray.NewRay(ctx) + link := ray.New(ctx) go func() { for { data, err := link.OutboundInput().ReadMultiBuffer() diff --git a/transport/ray/direct.go b/transport/ray/direct.go index 9be2aff6..ab48021a 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -12,11 +12,25 @@ import ( "v2ray.com/core/common/signal" ) -// NewRay creates a new Ray for direct traffic transport. -func NewRay(ctx context.Context) Ray { +type Option func(*Stream) + +type addInt64 interface { + Add(int64) int64 +} + +func WithStatCounter(c addInt64) Option { + return func(s *Stream) { + s.onDataSize = append(s.onDataSize, func(delta uint64) { + c.Add(int64(delta)) + }) + } +} + +// New creates a new Ray for direct traffic transport. +func New(ctx context.Context, opts ...Option) Ray { return &directRay{ - Input: NewStream(ctx), - Output: NewStream(ctx), + Input: NewStream(ctx, opts...), + Output: NewStream(ctx, opts...), } } @@ -60,18 +74,23 @@ type Stream struct { ctx context.Context readSignal *signal.Notifier writeSignal *signal.Notifier + onDataSize []func(uint64) close bool err bool } // NewStream creates a new Stream. -func NewStream(ctx context.Context) *Stream { - return &Stream{ +func NewStream(ctx context.Context, opts ...Option) *Stream { + s := &Stream{ ctx: ctx, readSignal: signal.NewNotifier(), writeSignal: signal.NewNotifier(), size: 0, } + for _, opt := range opts { + opt(s) + } + return s } func (s *Stream) getData() (buf.MultiBuffer, error) { @@ -201,8 +220,13 @@ func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error { s.data = buf.NewMultiBufferCap(128) } + dataSize := uint64(data.Len()) + for _, f := range s.onDataSize { + f(dataSize) + } + s.data.AppendMulti(data) - s.size += uint64(data.Len()) + s.size += dataSize s.writeSignal.Signal() return nil diff --git a/transport/ray/direct_test.go b/transport/ray/direct_test.go index 64c0ab1d..c365ac62 100644 --- a/transport/ray/direct_test.go +++ b/transport/ray/direct_test.go @@ -5,6 +5,7 @@ import ( "io" "testing" + "v2ray.com/core/app/stats" "v2ray.com/core/common/buf" . "v2ray.com/core/transport/ray" . "v2ray.com/ext/assert" @@ -47,3 +48,18 @@ func TestStreamClose(t *testing.T) { _, err = stream.ReadMultiBuffer() assert(err, Equals, io.EOF) } + +func TestStreamStatCounter(t *testing.T) { + assert := With(t) + + c := new(stats.Counter) + stream := NewStream(context.Background(), WithStatCounter(c)) + + b1 := buf.New() + b1.AppendBytes('a', 'b', 'c', 'd') + assert(stream.WriteMultiBuffer(buf.NewMultiBufferValue(b1)), IsNil) + + stream.Close() + + assert(c.Value(), Equals, int64(4)) +}