diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index 61169213..65c8debe 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -22,42 +22,6 @@ const ( maxTotal = 128 ) -type manager interface { - remove(id uint16) -} - -type session struct { - sync.Mutex - input ray.InputStream - output ray.OutputStream - parent manager - id uint16 - uplinkClosed bool - downlinkClosed bool -} - -func (s *session) closeUplink() { - var allDone bool - s.Lock() - s.uplinkClosed = true - allDone = s.uplinkClosed && s.downlinkClosed - s.Unlock() - if allDone { - s.parent.remove(s.id) - } -} - -func (s *session) closeDownlink() { - var allDone bool - s.Lock() - s.downlinkClosed = true - allDone = s.uplinkClosed && s.downlinkClosed - s.Unlock() - if allDone { - s.parent.remove(s.id) - } -} - type ClientManager struct { access sync.Mutex clients []*Client @@ -112,9 +76,7 @@ func (m *ClientManager) onClientFinish() { } type Client struct { - access sync.RWMutex - count uint16 - sessions map[uint16]*session + sessionManager *SessionManager inboundRay ray.InboundRay ctx context.Context cancel context.CancelFunc @@ -131,12 +93,11 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client pipe := ray.NewRay(ctx) go p.Process(ctx, pipe, dialer) c := &Client{ - sessions: make(map[uint16]*session, 256), + sessionManager: NewSessionManager(), inboundRay: pipe, ctx: ctx, cancel: cancel, manager: m, - count: 0, session2Remove: make(chan uint16, 16), concurrency: m.config.Concurrency, } @@ -145,14 +106,6 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client return c, nil } -func (m *Client) remove(id uint16) { - select { - case m.session2Remove <- id: - default: - // Probably not gonna happen. - } -} - func (m *Client) Closed() bool { select { case <-m.ctx.Done(): @@ -168,42 +121,28 @@ func (m *Client) monitor() { for { select { case <-m.ctx.Done(): - m.cleanup() + m.sessionManager.Close() + m.inboundRay.InboundInput().Close() + m.inboundRay.InboundOutput().CloseError() return - case id := <-m.session2Remove: - m.access.Lock() - delete(m.sessions, id) - if len(m.sessions) == 0 { + case <-time.After(time.Second * 6): + size := m.sessionManager.Size() + if size == 0 { m.cancel() } - m.access.Unlock() } } } -func (m *Client) cleanup() { - m.access.Lock() - defer m.access.Unlock() - - m.inboundRay.InboundInput().Close() - m.inboundRay.InboundOutput().CloseError() - - for _, s := range m.sessions { - s.closeUplink() - s.closeDownlink() - s.output.CloseError() - } -} - -func fetchInput(ctx context.Context, s *session, output buf.Writer) { +func fetchInput(ctx context.Context, s *Session, output buf.Writer) { dest, _ := proxy.TargetFromContext(ctx) writer := &Writer{ dest: dest, - id: s.id, + id: s.ID, writer: output, } defer writer.Close() - defer s.closeUplink() + defer s.CloseUplink() log.Trace(newError("dispatching request to ", dest)) data, _ := s.input.ReadTimeout(time.Millisecond * 500) @@ -218,22 +157,9 @@ func fetchInput(ctx context.Context, s *session, output buf.Writer) { } } -func waitForDone(ctx context.Context, s *session) { - <-ctx.Done() - s.closeUplink() - s.closeDownlink() - s.output.Close() -} - func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool { - m.access.Lock() - defer m.access.Unlock() - - if len(m.sessions) >= int(m.concurrency) { - return false - } - - if m.count >= maxTotal { + numSession := m.sessionManager.Size() + if numSession >= int(m.concurrency) || numSession >= maxTotal { return false } @@ -243,17 +169,13 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool default: } - m.count++ - id := m.count - s := &session{ + s := &Session{ input: outboundRay.OutboundInput(), output: outboundRay.OutboundOutput(), - parent: m, - id: id, + parent: m.sessionManager, } - m.sessions[id] = s + m.sessionManager.Allocate(s) go fetchInput(ctx, s, m.inboundRay.InboundInput()) - go waitForDone(ctx, s) return true } @@ -305,11 +227,9 @@ func (m *Client) fetchOutput() { continue } - m.access.RLock() - s, found := m.sessions[meta.SessionID] - m.access.RUnlock() + s, found := m.sessionManager.Get(meta.SessionID) if found && meta.SessionStatus == SessionStatusEnd { - s.closeDownlink() + s.CloseDownlink() s.output.Close() } if !meta.Option.Has(OptionData) { @@ -354,34 +274,27 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (ray.Inboun ray := ray.NewRay(ctx) worker := &ServerWorker{ - dispatcher: s.dispatcher, - outboundRay: ray, - sessions: make(map[uint16]*session), + dispatcher: s.dispatcher, + outboundRay: ray, + sessionManager: NewSessionManager(), } go worker.run(ctx) return ray, nil } type ServerWorker struct { - dispatcher dispatcher.Interface - outboundRay ray.OutboundRay - sessions map[uint16]*session - access sync.RWMutex + dispatcher dispatcher.Interface + outboundRay ray.OutboundRay + sessionManager *SessionManager } -func (w *ServerWorker) remove(id uint16) { - w.access.Lock() - delete(w.sessions, id) - w.access.Unlock() -} - -func handle(ctx context.Context, s *session, output buf.Writer) { - writer := NewResponseWriter(s.id, output) +func handle(ctx context.Context, s *Session, output buf.Writer) { + writer := NewResponseWriter(s.ID, output) if err := buf.PipeUntilEOF(signal.BackgroundTimer(), s.input, writer); err != nil { - log.Trace(newError("session ", s.id, " ends: ").Base(err)) + log.Trace(newError("session ", s.ID, " ends: ").Base(err)) } writer.Close() - s.closeDownlink() + s.CloseDownlink() } func (w *ServerWorker) run(ctx context.Context) { @@ -410,12 +323,9 @@ func (w *ServerWorker) run(ctx context.Context) { continue } - w.access.RLock() - s, found := w.sessions[meta.SessionID] - w.access.RUnlock() - + s, found := w.sessionManager.Get(meta.SessionID) if found && meta.SessionStatus == SessionStatusEnd { - s.closeUplink() + s.CloseUplink() s.output.Close() } @@ -426,15 +336,13 @@ func (w *ServerWorker) run(ctx context.Context) { log.Trace(newError("failed to dispatch request.").Base(err)) continue } - s = &session{ + s = &Session{ input: inboundRay.InboundOutput(), output: inboundRay.InboundInput(), - parent: w, - id: meta.SessionID, + parent: w.sessionManager, + ID: meta.SessionID, } - w.access.Lock() - w.sessions[meta.SessionID] = s - w.access.Unlock() + w.sessionManager.Add(s) go handle(ctx, s, w.outboundRay.OutboundOutput()) } diff --git a/app/proxyman/mux/session.go b/app/proxyman/mux/session.go new file mode 100644 index 00000000..33e9efee --- /dev/null +++ b/app/proxyman/mux/session.go @@ -0,0 +1,99 @@ +package mux + +import ( + "sync" + + "v2ray.com/core/transport/ray" +) + +type SessionManager struct { + sync.RWMutex + count uint16 + sessions map[uint16]*Session +} + +func NewSessionManager() *SessionManager { + return &SessionManager{ + count: 0, + sessions: make(map[uint16]*Session, 32), + } +} + +func (m *SessionManager) Size() int { + m.RLock() + defer m.RUnlock() + + return len(m.sessions) +} + +func (m *SessionManager) Allocate(s *Session) { + m.Lock() + defer m.Unlock() + + m.count++ + s.ID = m.count + m.sessions[s.ID] = s +} + +func (m *SessionManager) Add(s *Session) { + m.Lock() + defer m.Unlock() + + m.sessions[s.ID] = s +} + +func (m *SessionManager) Remove(id uint16) { + m.Lock() + defer m.Unlock() + + delete(m.sessions, id) +} + +func (m *SessionManager) Get(id uint16) (*Session, bool) { + m.RLock() + defer m.RUnlock() + + s, found := m.sessions[id] + return s, found +} + +func (m *SessionManager) Close() { + m.RLock() + defer m.RUnlock() + + for _, s := range m.sessions { + s.output.CloseError() + } +} + +type Session struct { + sync.Mutex + input ray.InputStream + output ray.OutputStream + parent *SessionManager + ID uint16 + uplinkClosed bool + downlinkClosed bool +} + +func (s *Session) CloseUplink() { + var allDone bool + s.Lock() + s.uplinkClosed = true + allDone = s.uplinkClosed && s.downlinkClosed + s.Unlock() + if allDone { + s.parent.Remove(s.ID) + } +} + +func (s *Session) CloseDownlink() { + var allDone bool + s.Lock() + s.downlinkClosed = true + allDone = s.uplinkClosed && s.downlinkClosed + s.Unlock() + if allDone { + s.parent.Remove(s.ID) + } +} diff --git a/app/proxyman/mux/session_test.go b/app/proxyman/mux/session_test.go new file mode 100644 index 00000000..e57de2f5 --- /dev/null +++ b/app/proxyman/mux/session_test.go @@ -0,0 +1,28 @@ +package mux_test + +import ( + "testing" + + . "v2ray.com/core/app/proxyman/mux" + "v2ray.com/core/testing/assert" +) + +func TestSessionManagerAdd(t *testing.T) { + assert := assert.On(t) + + m := NewSessionManager() + + s := &Session{} + m.Allocate(s) + assert.Uint16(s.ID).Equals(1) + + s = &Session{} + m.Allocate(s) + assert.Uint16(s.ID).Equals(2) + + s = &Session{ + ID: 4, + } + m.Add(s) + assert.Uint16(s.ID).Equals(4) +} diff --git a/app/proxyman/mux/status.go b/app/proxyman/mux/status.go new file mode 100644 index 00000000..84d2da28 --- /dev/null +++ b/app/proxyman/mux/status.go @@ -0,0 +1,3 @@ +package mux + +type statusHandler func(meta *FrameMetadata) error