From 67f409de04bea112a095258f7d8fb0f330582a9d Mon Sep 17 00:00:00 2001 From: Ye Zhihao Date: Sat, 3 Oct 2020 03:06:32 +0800 Subject: [PATCH] Stats: Implements blocking/non-blocking messaging of Channel (#250) --- app/router/command/command.go | 13 +- app/router/command/command_test.go | 289 ++++++++++++++++------------- app/stats/channel.go | 92 +++++---- app/stats/channel_test.go | 227 +++++++++++++--------- app/stats/config.pb.go | 49 +++-- app/stats/config.proto | 6 +- app/stats/stats.go | 2 +- features/stats/stats.go | 6 +- 8 files changed, 401 insertions(+), 283 deletions(-) diff --git a/app/router/command/command.go b/app/router/command/command.go index 6add0441..378bb7f7 100644 --- a/app/router/command/command.go +++ b/app/router/command/command.go @@ -6,6 +6,7 @@ package command import ( "context" + "time" "google.golang.org/grpc" @@ -38,7 +39,8 @@ func (s *routingServer) TestRoute(ctx context.Context, request *TestRouteRequest return nil, err } if request.PublishResult && s.routingStats != nil { - s.routingStats.Publish(route) + ctx, _ := context.WithTimeout(context.Background(), 4*time.Second) // nolint: lostcancel + s.routingStats.Publish(ctx, route) } return AsProtobufMessage(request.FieldSelectors)(route), nil } @@ -55,10 +57,13 @@ func (s *routingServer) SubscribeRoutingStats(request *SubscribeRoutingStatsRequ defer stats.UnsubscribeClosableChannel(s.routingStats, subscriber) // nolint: errcheck for { select { - case value, received := <-subscriber: + case value, ok := <-subscriber: + if !ok { + return newError("Upstream closed the subscriber channel.") + } route, ok := value.(routing.Route) - if !(received && ok) { - return newError("Receiving upstream statistics failed.") + if !ok { + return newError("Upstream sent malformed statistics.") } err := stream.Send(genMessage(route)) if err != nil { diff --git a/app/router/command/command_test.go b/app/router/command/command_test.go index d9fcf585..258ff1a0 100644 --- a/app/router/command/command_test.go +++ b/app/router/command/command_test.go @@ -21,9 +21,9 @@ import ( func TestServiceSubscribeRoutingStats(t *testing.T) { c := stats.NewChannel(&stats.ChannelConfig{ - SubscriberLimit: 1, - BufferSize: 16, - BroadcastTimeout: 100, + SubscriberLimit: 1, + BufferSize: 0, + Blocking: true, }) common.Must(c.Start()) defer c.Close() @@ -55,122 +55,138 @@ func TestServiceSubscribeRoutingStats(t *testing.T) { // Publisher goroutine go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - for { // Wait until there's one subscriber in routing stats channel - if len(c.Subscribers()) > 0 { - break + publishTestCases := func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for { // Wait until there's one subscriber in routing stats channel + if len(c.Subscribers()) > 0 { + break + } + if ctx.Err() != nil { + return ctx.Err() + } } - if ctx.Err() != nil { - errCh <- ctx.Err() + for _, tc := range testCases { + c.Publish(context.Background(), AsRoutingRoute(tc)) + time.Sleep(time.Millisecond) } + return nil } - for _, tc := range testCases { - c.Publish(AsRoutingRoute(tc)) + + if err := publishTestCases(); err != nil { + errCh <- err } // Wait for next round of publishing <-nextPub - ctx, cancel = context.WithTimeout(context.Background(), time.Second) - defer cancel() - for { // Wait until there's one subscriber in routing stats channel - if len(c.Subscribers()) > 0 { - break - } - if ctx.Err() != nil { - errCh <- ctx.Err() - } - } - for _, tc := range testCases { - c.Publish(AsRoutingRoute(tc)) + if err := publishTestCases(); err != nil { + errCh <- err } }() // Client goroutine go func() { + defer lis.Close() conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) if err != nil { errCh <- err + return } - defer lis.Close() defer conn.Close() client := NewRoutingServiceClient(conn) // Test retrieving all fields - streamCtx, streamClose := context.WithCancel(context.Background()) - stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{}) - if err != nil { - errCh <- err - } + testRetrievingAllFields := func() error { + streamCtx, streamClose := context.WithCancel(context.Background()) - for _, tc := range testCases { - msg, err := stream.Recv() + // Test the unsubscription of stream works well + defer func() { + streamClose() + timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second) + defer timeout() + for { // Wait until there's no subscriber in routing stats channel + if len(c.Subscribers()) == 0 { + break + } + if timeOutCtx.Err() != nil { + t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err()) + } + } + }() + + stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{}) if err != nil { - errCh <- err + return err } - if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { - t.Error(r) - } - } - // Test that double subscription will fail - errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{ - FieldSelectors: []string{"ip", "port", "domain", "outbound"}, - }) - if err != nil { - errCh <- err - } - if _, err := errStream.Recv(); err == nil { - t.Error("unexpected successful subscription") - } + for _, tc := range testCases { + msg, err := stream.Recv() + if err != nil { + return err + } + if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + } - // Test the unsubscription of stream works well - streamClose() - timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second) - defer timeout() - for { // Wait until there's no subscriber in routing stats channel - if len(c.Subscribers()) == 0 { - break + // Test that double subscription will fail + errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{ + FieldSelectors: []string{"ip", "port", "domain", "outbound"}, + }) + if err != nil { + return err } - if timeOutCtx.Err() != nil { - t.Error("unexpected subscribers not decreased in channel") - errCh <- timeOutCtx.Err() + if _, err := errStream.Recv(); err == nil { + t.Error("unexpected successful subscription") } + + return nil } // Test retrieving only a subset of fields - streamCtx, streamClose = context.WithCancel(context.Background()) - stream, err = client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{ - FieldSelectors: []string{"ip", "port", "domain", "outbound"}, - }) - if err != nil { + testRetrievingSubsetOfFields := func() error { + streamCtx, streamClose := context.WithCancel(context.Background()) + defer streamClose() + stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{ + FieldSelectors: []string{"ip", "port", "domain", "outbound"}, + }) + if err != nil { + return err + } + + // Send nextPub signal to start next round of publishing + close(nextPub) + + for _, tc := range testCases { + msg, err := stream.Recv() + if err != nil { + return err + } + stat := &RoutingContext{ // Only a subset of stats is retrieved + SourceIPs: tc.SourceIPs, + TargetIPs: tc.TargetIPs, + SourcePort: tc.SourcePort, + TargetPort: tc.TargetPort, + TargetDomain: tc.TargetDomain, + OutboundGroupTags: tc.OutboundGroupTags, + OutboundTag: tc.OutboundTag, + } + if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + } + + return nil + } + + if err := testRetrievingAllFields(); err != nil { errCh <- err } - - close(nextPub) // Send nextPub signal to start next round of publishing - for _, tc := range testCases { - msg, err := stream.Recv() - stat := &RoutingContext{ // Only a subset of stats is retrieved - SourceIPs: tc.SourceIPs, - TargetIPs: tc.TargetIPs, - SourcePort: tc.SourcePort, - TargetPort: tc.TargetPort, - TargetDomain: tc.TargetDomain, - OutboundGroupTags: tc.OutboundGroupTags, - OutboundTag: tc.OutboundTag, - } - if err != nil { - errCh <- err - } - if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { - t.Error(r) - } + if err := testRetrievingSubsetOfFields(); err != nil { + errCh <- err } - streamClose() - - // Client passed all tests successfully - errCh <- nil + errCh <- nil // Client passed all tests successfully }() // Wait for goroutines to complete @@ -186,9 +202,9 @@ func TestServiceSubscribeRoutingStats(t *testing.T) { func TestSerivceTestRoute(t *testing.T) { c := stats.NewChannel(&stats.ChannelConfig{ - SubscriberLimit: 1, - BufferSize: 16, - BroadcastTimeout: 100, + SubscriberLimit: 1, + BufferSize: 16, + Blocking: true, }) common.Must(c.Start()) defer c.Close() @@ -249,11 +265,11 @@ func TestSerivceTestRoute(t *testing.T) { // Client goroutine go func() { + defer lis.Close() conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) if err != nil { errCh <- err } - defer lis.Close() defer conn.Close() client := NewRoutingServiceClient(conn) @@ -268,58 +284,69 @@ func TestSerivceTestRoute(t *testing.T) { } // Test simple TestRoute - for _, tc := range testCases { - route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc}) - if err != nil { - errCh <- err - } - if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { - t.Error(r) + testSimple := func() error { + for _, tc := range testCases { + route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc}) + if err != nil { + return err + } + if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } } + return nil } // Test TestRoute with special options - sub, err := c.Subscribe() - if err != nil { - errCh <- err - } - for _, tc := range testCases { - route, err := client.TestRoute(context.Background(), &TestRouteRequest{ - RoutingContext: tc, - FieldSelectors: []string{"ip", "port", "domain", "outbound"}, - PublishResult: true, - }) - stat := &RoutingContext{ // Only a subset of stats is retrieved - SourceIPs: tc.SourceIPs, - TargetIPs: tc.TargetIPs, - SourcePort: tc.SourcePort, - TargetPort: tc.TargetPort, - TargetDomain: tc.TargetDomain, - OutboundGroupTags: tc.OutboundGroupTags, - OutboundTag: tc.OutboundTag, - } + testOptions := func() error { + sub, err := c.Subscribe() if err != nil { - errCh <- err + return err } - if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { - t.Error(r) - } - select { // Check that routing result has been published to statistics channel - case msg, received := <-sub: - if route, ok := msg.(routing.Route); received && ok { - if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { - t.Error(r) - } - } else { - t.Error("unexpected failure in receiving published routing result") + for _, tc := range testCases { + route, err := client.TestRoute(context.Background(), &TestRouteRequest{ + RoutingContext: tc, + FieldSelectors: []string{"ip", "port", "domain", "outbound"}, + PublishResult: true, + }) + if err != nil { + return err + } + stat := &RoutingContext{ // Only a subset of stats is retrieved + SourceIPs: tc.SourceIPs, + TargetIPs: tc.TargetIPs, + SourcePort: tc.SourcePort, + TargetPort: tc.TargetPort, + TargetDomain: tc.TargetDomain, + OutboundGroupTags: tc.OutboundGroupTags, + OutboundTag: tc.OutboundTag, + } + if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + select { // Check that routing result has been published to statistics channel + case msg, received := <-sub: + if route, ok := msg.(routing.Route); received && ok { + if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { + t.Error(r) + } + } else { + t.Error("unexpected failure in receiving published routing result for testcase", tc) + } + case <-time.After(100 * time.Millisecond): + t.Error("unexpected failure in receiving published routing result", tc) } - case <-time.After(100 * time.Millisecond): - t.Error("unexpected failure in receiving published routing result") } + return nil } - // Client passed all tests successfully - errCh <- nil + if err := testSimple(); err != nil { + errCh <- err + } + if err := testOptions(); err != nil { + errCh <- err + } + errCh <- nil // Client passed all tests successfully }() // Wait for goroutines to complete diff --git a/app/stats/channel.go b/app/stats/channel.go index dd484fab..2fd54468 100644 --- a/app/stats/channel.go +++ b/app/stats/channel.go @@ -3,15 +3,15 @@ package stats import ( + "context" "sync" - "time" "v2ray.com/core/common" ) // Channel is an implementation of stats.Channel. type Channel struct { - channel chan interface{} + channel chan channelMessage subscribers []chan interface{} // Synchronization components @@ -19,28 +19,21 @@ type Channel struct { closed chan struct{} // Channel options - subscriberLimit int // Set to 0 as no subscriber limit - channelBufferSize int // Set to 0 as no buffering - broadcastTimeout time.Duration // Set to 0 as non-blocking immediate timeout + blocking bool // Set blocking state if channel buffer reaches limit + bufferSize int // Set to 0 as no buffering + subsLimit int // Set to 0 as no subscriber limit } // NewChannel creates an instance of Statistics Channel. func NewChannel(config *ChannelConfig) *Channel { return &Channel{ - channel: make(chan interface{}, config.BufferSize), - subscriberLimit: int(config.SubscriberLimit), - channelBufferSize: int(config.BufferSize), - broadcastTimeout: time.Duration(config.BroadcastTimeout+1) * time.Millisecond, + channel: make(chan channelMessage, config.BufferSize), + subsLimit: int(config.SubscriberLimit), + bufferSize: int(config.BufferSize), + blocking: config.Blocking, } } -// Channel returns the underlying go channel. -func (c *Channel) Channel() chan interface{} { - c.access.RLock() - defer c.access.RUnlock() - return c.channel -} - // Subscribers implements stats.Channel. func (c *Channel) Subscribers() []chan interface{} { c.access.RLock() @@ -52,10 +45,10 @@ func (c *Channel) Subscribers() []chan interface{} { func (c *Channel) Subscribe() (chan interface{}, error) { c.access.Lock() defer c.access.Unlock() - if c.subscriberLimit > 0 && len(c.subscribers) >= c.subscriberLimit { + if c.subsLimit > 0 && len(c.subscribers) >= c.subsLimit { return nil, newError("Number of subscribers has reached limit") } - subscriber := make(chan interface{}, c.channelBufferSize) + subscriber := make(chan interface{}, c.bufferSize) c.subscribers = append(c.subscribers, subscriber) return subscriber, nil } @@ -77,16 +70,17 @@ func (c *Channel) Unsubscribe(subscriber chan interface{}) error { } // Publish implements stats.Channel. -func (c *Channel) Publish(message interface{}) { +func (c *Channel) Publish(ctx context.Context, msg interface{}) { select { // Early exit if channel closed case <-c.closed: return default: - } - select { // Drop message if not successfully sent - case c.channel <- message: - default: - return + pub := channelMessage{context: ctx, message: msg} + if c.blocking { + pub.publish(c.channel) + } else { + pub.publishNonBlocking(c.channel) + } } } @@ -111,13 +105,12 @@ func (c *Channel) Start() error { go func() { for { select { - case message := <-c.channel: // Broadcast message - for _, sub := range c.Subscribers() { // Concurrency-safe subscribers retreivement - select { - case sub <- message: // Successfully sent message - case <-time.After(c.broadcastTimeout): // Remove timeout subscriber - common.Must(c.Unsubscribe(sub)) - close(sub) // Actively close subscriber as notification + case pub := <-c.channel: // Published message received + for _, sub := range c.Subscribers() { // Concurrency-safe subscribers retrievement + if c.blocking { + pub.broadcast(sub) + } else { + pub.broadcastNonBlocking(sub) } } case <-c.closed: // Channel closed @@ -142,3 +135,40 @@ func (c *Channel) Close() error { } return nil } + +// channelMessage is the published message with guaranteed delivery. +// message is discarded only when the context is early cancelled. +type channelMessage struct { + context context.Context + message interface{} +} + +func (c channelMessage) publish(publisher chan channelMessage) { + select { + case publisher <- c: + case <-c.context.Done(): + } +} + +func (c channelMessage) publishNonBlocking(publisher chan channelMessage) { + select { + case publisher <- c: + default: // Create another goroutine to keep sending message + go c.publish(publisher) + } +} + +func (c channelMessage) broadcast(subscriber chan interface{}) { + select { + case subscriber <- c.message: + case <-c.context.Done(): + } +} + +func (c channelMessage) broadcastNonBlocking(subscriber chan interface{}) { + select { + case subscriber <- c.message: + default: // Create another goroutine to keep sending message + go c.broadcast(subscriber) + } +} diff --git a/app/stats/channel_test.go b/app/stats/channel_test.go index 6458711b..d32c0c56 100644 --- a/app/stats/channel_test.go +++ b/app/stats/channel_test.go @@ -1,6 +1,7 @@ package stats_test import ( + "context" "fmt" "testing" "time" @@ -12,8 +13,7 @@ import ( func TestStatsChannel(t *testing.T) { // At most 2 subscribers could be registered - c := NewChannel(&ChannelConfig{SubscriberLimit: 2}) - source := c.Channel() + c := NewChannel(&ChannelConfig{SubscriberLimit: 2, Blocking: true}) a, err := stats.SubscribeRunnableChannel(c) common.Must(err) @@ -34,21 +34,12 @@ func TestStatsChannel(t *testing.T) { stopCh := make(chan struct{}) errCh := make(chan string) - go func() { // Blocking publish - source <- 1 - source <- 2 - source <- "3" - source <- []int{4} - source <- nil // Dummy messsage with no subscriber receiving, will block reading goroutine - for i := 0; i < cap(source); i++ { - source <- nil // Fill source channel's buffer - } - select { - case source <- nil: // Source writing should be blocked here, for last message was not cleared and buffer was full - errCh <- fmt.Sprint("unexpected non-blocked source channel") - default: - close(stopCh) - } + go func() { + c.Publish(context.Background(), 1) + c.Publish(context.Background(), 2) + c.Publish(context.Background(), "3") + c.Publish(context.Background(), []int{4}) + stopCh <- struct{}{} }() go func() { @@ -64,6 +55,7 @@ func TestStatsChannel(t *testing.T) { if v, ok := (<-a).([]int); !ok || v[0] != 4 { errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) } + stopCh <- struct{}{} }() go func() { @@ -79,14 +71,18 @@ func TestStatsChannel(t *testing.T) { if v, ok := (<-b).([]int); !ok || v[0] != 4 { errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) } + stopCh <- struct{}{} }() - select { - case <-time.After(2 * time.Second): - t.Fatal("Test timeout after 2s") - case e := <-errCh: - t.Fatal(e) - case <-stopCh: + timeout := time.After(2 * time.Second) + for i := 0; i < 3; i++ { + select { + case <-timeout: + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } } // Test the unsubscription of channel @@ -100,12 +96,10 @@ func TestStatsChannel(t *testing.T) { } func TestStatsChannelUnsubcribe(t *testing.T) { - c := NewChannel(&ChannelConfig{}) + c := NewChannel(&ChannelConfig{Blocking: true}) common.Must(c.Start()) defer c.Close() - source := c.Channel() - a, err := c.Subscribe() common.Must(err) defer c.Unsubscribe(a) @@ -133,9 +127,9 @@ func TestStatsChannelUnsubcribe(t *testing.T) { } go func() { // Blocking publish - source <- 1 + c.Publish(context.Background(), 1) <-pauseCh // Wait for `b` goroutine to resume sending message - source <- 2 + c.Publish(context.Background(), 2) }() go func() { @@ -151,7 +145,7 @@ func TestStatsChannelUnsubcribe(t *testing.T) { if v, ok := (<-b).(int); !ok || v != 1 { errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) } - // Unsubscribe `b` while `source`'s messaging is paused + // Unsubscribe `b` while publishing is paused c.Unsubscribe(b) { // Test `b` is not in subscribers var aSet, bSet bool @@ -167,7 +161,7 @@ func TestStatsChannelUnsubcribe(t *testing.T) { errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) } } - // Resume `source`'s progress + // Resume publishing progress close(pauseCh) // Test `b` is neither closed nor able to receive any data select { @@ -191,78 +185,142 @@ func TestStatsChannelUnsubcribe(t *testing.T) { } } -func TestStatsChannelTimeout(t *testing.T) { +func TestStatsChannelBlocking(t *testing.T) { // Do not use buffer so as to create blocking scenario - c := NewChannel(&ChannelConfig{BufferSize: 0, BroadcastTimeout: 50}) + c := NewChannel(&ChannelConfig{BufferSize: 0, Blocking: true}) common.Must(c.Start()) defer c.Close() - source := c.Channel() - a, err := c.Subscribe() common.Must(err) defer c.Unsubscribe(a) - b, err := c.Subscribe() - common.Must(err) - defer c.Unsubscribe(b) - + pauseCh := make(chan struct{}) stopCh := make(chan struct{}) errCh := make(chan string) - go func() { // Blocking publish - source <- 1 - source <- 2 + ctx, cancel := context.WithCancel(context.Background()) + + // Test blocking channel publishing + go func() { + // Dummy messsage with no subscriber receiving, will block broadcasting goroutine + c.Publish(context.Background(), nil) + + <-pauseCh + + // Publishing should be blocked here, for last message was not cleared and buffer was full + c.Publish(context.Background(), nil) + + pauseCh <- struct{}{} + + // Publishing should still be blocked here + c.Publish(ctx, nil) + + // Check publishing is done because context is canceled + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + errCh <- fmt.Sprint("unexpected error: ", ctx.Err()) + } + default: + errCh <- "unexpected non-blocked publishing" + } + close(stopCh) }() go func() { - if v, ok := (<-a).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + pauseCh <- struct{}{} + + select { + case <-pauseCh: + errCh <- "unexpected non-blocked publishing" + case <-time.After(100 * time.Millisecond): } - if v, ok := (<-a).(int); !ok || v != 2 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + + // Receive first published message + <-a + + select { + case <-pauseCh: + case <-time.After(100 * time.Millisecond): + errCh <- "unexpected blocking publishing" } - { // Test `b` is still in subscribers yet (because `a` receives 2 first) - var aSet, bSet bool - for _, s := range c.Subscribers() { - if s == a { - aSet = true - } - if s == b { - bSet = true - } - } - if !(aSet && bSet) { - errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) + + // Manually cancel the context to end publishing + cancel() + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} + +func TestStatsChannelNonBlocking(t *testing.T) { + // Do not use buffer so as to create blocking scenario + c := NewChannel(&ChannelConfig{BufferSize: 0, Blocking: false}) + common.Must(c.Start()) + defer c.Close() + + a, err := c.Subscribe() + common.Must(err) + defer c.Unsubscribe(a) + + pauseCh := make(chan struct{}) + stopCh := make(chan struct{}) + errCh := make(chan string) + + ctx, cancel := context.WithCancel(context.Background()) + + // Test blocking channel publishing + go func() { + c.Publish(context.Background(), nil) + c.Publish(context.Background(), nil) + pauseCh <- struct{}{} + <-pauseCh + c.Publish(ctx, nil) + c.Publish(ctx, nil) + // Check publishing is done because context is canceled + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + errCh <- fmt.Sprint("unexpected error: ", ctx.Err()) } + case <-time.After(100 * time.Millisecond): + errCh <- "unexpected non-cancelled publishing" } }() go func() { - if v, ok := (<-b).(int); !ok || v != 1 { - errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + // Check publishing won't block even if there is no subscriber receiving message + select { + case <-pauseCh: + case <-time.After(100 * time.Millisecond): + errCh <- "unexpected blocking publishing" } - // Block `b` channel for a time longer than `source`'s timeout - <-time.After(200 * time.Millisecond) - { // Test `b` has been unsubscribed by source - var aSet, bSet bool - for _, s := range c.Subscribers() { - if s == a { - aSet = true - } - if s == b { - bSet = true - } - } - if !(aSet && !bSet) { - errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) - } + + // Receive first and second published message + <-a + <-a + + pauseCh <- struct{}{} + + // Manually cancel the context to end publishing + cancel() + + // Check third and forth published message is cancelled and cannot receive + <-time.After(100 * time.Millisecond) + select { + case <-a: + errCh <- "unexpected non-cancelled publishing" + default: } - select { // Test `b` has been closed by source - case v, ok := <-b: - if ok { - errCh <- fmt.Sprint("unexpected data received: ", v) - } + select { + case <-a: + errCh <- "unexpected non-cancelled publishing" default: } close(stopCh) @@ -279,12 +337,10 @@ func TestStatsChannelTimeout(t *testing.T) { func TestStatsChannelConcurrency(t *testing.T) { // Do not use buffer so as to create blocking scenario - c := NewChannel(&ChannelConfig{BufferSize: 0, BroadcastTimeout: 100}) + c := NewChannel(&ChannelConfig{BufferSize: 0, Blocking: true}) common.Must(c.Start()) defer c.Close() - source := c.Channel() - a, err := c.Subscribe() common.Must(err) defer c.Unsubscribe(a) @@ -297,8 +353,8 @@ func TestStatsChannelConcurrency(t *testing.T) { errCh := make(chan string) go func() { // Blocking publish - source <- 1 - source <- 2 + c.Publish(context.Background(), 1) + c.Publish(context.Background(), 2) }() go func() { @@ -311,8 +367,7 @@ func TestStatsChannelConcurrency(t *testing.T) { }() go func() { - // Block `b` for a time shorter than `source`'s timeout - // So as to ensure source channel is trying to send message to `b`. + // Block `b` for a time so as to ensure source channel is trying to send message to `b`. <-time.After(25 * time.Millisecond) // This causes concurrency scenario: unsubscribe `b` while trying to send message to it c.Unsubscribe(b) diff --git a/app/stats/config.pb.go b/app/stats/config.pb.go index f9402fc7..6fd8f1b5 100644 --- a/app/stats/config.pb.go +++ b/app/stats/config.pb.go @@ -68,9 +68,9 @@ type ChannelConfig struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - SubscriberLimit int32 `protobuf:"varint,1,opt,name=SubscriberLimit,proto3" json:"SubscriberLimit,omitempty"` - BufferSize int32 `protobuf:"varint,2,opt,name=BufferSize,proto3" json:"BufferSize,omitempty"` - BroadcastTimeout int32 `protobuf:"varint,3,opt,name=BroadcastTimeout,proto3" json:"BroadcastTimeout,omitempty"` + Blocking bool `protobuf:"varint,1,opt,name=Blocking,proto3" json:"Blocking,omitempty"` + SubscriberLimit int32 `protobuf:"varint,2,opt,name=SubscriberLimit,proto3" json:"SubscriberLimit,omitempty"` + BufferSize int32 `protobuf:"varint,3,opt,name=BufferSize,proto3" json:"BufferSize,omitempty"` } func (x *ChannelConfig) Reset() { @@ -105,6 +105,13 @@ func (*ChannelConfig) Descriptor() ([]byte, []int) { return file_app_stats_config_proto_rawDescGZIP(), []int{1} } +func (x *ChannelConfig) GetBlocking() bool { + if x != nil { + return x.Blocking + } + return false +} + func (x *ChannelConfig) GetSubscriberLimit() int32 { if x != nil { return x.SubscriberLimit @@ -119,34 +126,26 @@ func (x *ChannelConfig) GetBufferSize() int32 { return 0 } -func (x *ChannelConfig) GetBroadcastTimeout() int32 { - if x != nil { - return x.BroadcastTimeout - } - return 0 -} - var File_app_stats_config_proto protoreflect.FileDescriptor var file_app_stats_config_proto_rawDesc = []byte{ 0x0a, 0x16, 0x61, 0x70, 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x22, 0x08, - 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x85, 0x01, 0x0a, 0x0d, 0x43, 0x68, 0x61, - 0x6e, 0x6e, 0x65, 0x6c, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x53, 0x75, - 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x0f, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c, - 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x53, 0x69, - 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, - 0x53, 0x69, 0x7a, 0x65, 0x12, 0x2a, 0x0a, 0x10, 0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73, - 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x10, - 0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, - 0x42, 0x4d, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, - 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x50, 0x01, 0x5a, 0x18, - 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, - 0x70, 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79, - 0x2e, 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x73, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x75, 0x0a, 0x0d, 0x43, 0x68, 0x61, 0x6e, + 0x6e, 0x65, 0x6c, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x42, 0x6c, 0x6f, + 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x42, 0x6c, 0x6f, + 0x63, 0x6b, 0x69, 0x6e, 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, + 0x62, 0x65, 0x72, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0f, + 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x12, + 0x1e, 0x0a, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x53, 0x69, 0x7a, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x53, 0x69, 0x7a, 0x65, 0x42, + 0x4d, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, + 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x50, 0x01, 0x5a, 0x18, 0x76, + 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, + 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e, + 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x73, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/app/stats/config.proto b/app/stats/config.proto index 0ea911fd..f2d0da9c 100644 --- a/app/stats/config.proto +++ b/app/stats/config.proto @@ -11,7 +11,7 @@ message Config { } message ChannelConfig { - int32 SubscriberLimit = 1; - int32 BufferSize = 2; - int32 BroadcastTimeout = 3; + bool Blocking = 1; + int32 SubscriberLimit = 2; + int32 BufferSize = 3; } diff --git a/app/stats/stats.go b/app/stats/stats.go index 8a5a1eb0..0e14faf2 100644 --- a/app/stats/stats.go +++ b/app/stats/stats.go @@ -94,7 +94,7 @@ func (m *Manager) RegisterChannel(name string) (stats.Channel, error) { return nil, newError("Channel ", name, " already registered.") } newError("create new channel ", name).AtDebug().WriteToLog() - c := NewChannel(&ChannelConfig{BufferSize: 16, BroadcastTimeout: 100}) + c := NewChannel(&ChannelConfig{BufferSize: 64, Blocking: false}) m.channels[name] = c if m.running { return c, c.Start() diff --git a/features/stats/stats.go b/features/stats/stats.go index 73fae0f4..f92c75d0 100644 --- a/features/stats/stats.go +++ b/features/stats/stats.go @@ -3,6 +3,8 @@ package stats //go:generate errorgen import ( + "context" + "v2ray.com/core/common" "v2ray.com/core/features" ) @@ -25,8 +27,8 @@ type Counter interface { type Channel interface { // Channel is a runnable unit. common.Runnable - // Publish broadcasts a message through the channel. - Publish(interface{}) + // Publish broadcasts a message through the channel with a controlling context. + Publish(context.Context, interface{}) // SubscriberCount returns the number of the subscribers. Subscribers() []chan interface{} // Subscribe registers for listening to channel stream and returns a new listener channel.