split session manager out of mux client and server

pull/432/head
Darien Raymond 2017-04-12 22:31:11 +02:00
parent 4f4ced6b02
commit ad083989aa
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
4 changed files with 164 additions and 126 deletions

View File

@ -22,42 +22,6 @@ const (
maxTotal = 128 maxTotal = 128
) )
type manager interface {
remove(id uint16)
}
type session struct {
sync.Mutex
input ray.InputStream
output ray.OutputStream
parent manager
id uint16
uplinkClosed bool
downlinkClosed bool
}
func (s *session) closeUplink() {
var allDone bool
s.Lock()
s.uplinkClosed = true
allDone = s.uplinkClosed && s.downlinkClosed
s.Unlock()
if allDone {
s.parent.remove(s.id)
}
}
func (s *session) closeDownlink() {
var allDone bool
s.Lock()
s.downlinkClosed = true
allDone = s.uplinkClosed && s.downlinkClosed
s.Unlock()
if allDone {
s.parent.remove(s.id)
}
}
type ClientManager struct { type ClientManager struct {
access sync.Mutex access sync.Mutex
clients []*Client clients []*Client
@ -112,9 +76,7 @@ func (m *ClientManager) onClientFinish() {
} }
type Client struct { type Client struct {
access sync.RWMutex sessionManager *SessionManager
count uint16
sessions map[uint16]*session
inboundRay ray.InboundRay inboundRay ray.InboundRay
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@ -131,12 +93,11 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
pipe := ray.NewRay(ctx) pipe := ray.NewRay(ctx)
go p.Process(ctx, pipe, dialer) go p.Process(ctx, pipe, dialer)
c := &Client{ c := &Client{
sessions: make(map[uint16]*session, 256), sessionManager: NewSessionManager(),
inboundRay: pipe, inboundRay: pipe,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
manager: m, manager: m,
count: 0,
session2Remove: make(chan uint16, 16), session2Remove: make(chan uint16, 16),
concurrency: m.config.Concurrency, concurrency: m.config.Concurrency,
} }
@ -145,14 +106,6 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
return c, nil return c, nil
} }
func (m *Client) remove(id uint16) {
select {
case m.session2Remove <- id:
default:
// Probably not gonna happen.
}
}
func (m *Client) Closed() bool { func (m *Client) Closed() bool {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@ -168,42 +121,28 @@ func (m *Client) monitor() {
for { for {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
m.cleanup() m.sessionManager.Close()
m.inboundRay.InboundInput().Close()
m.inboundRay.InboundOutput().CloseError()
return return
case id := <-m.session2Remove: case <-time.After(time.Second * 6):
m.access.Lock() size := m.sessionManager.Size()
delete(m.sessions, id) if size == 0 {
if len(m.sessions) == 0 {
m.cancel() m.cancel()
} }
m.access.Unlock()
} }
} }
} }
func (m *Client) cleanup() { func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
m.access.Lock()
defer m.access.Unlock()
m.inboundRay.InboundInput().Close()
m.inboundRay.InboundOutput().CloseError()
for _, s := range m.sessions {
s.closeUplink()
s.closeDownlink()
s.output.CloseError()
}
}
func fetchInput(ctx context.Context, s *session, output buf.Writer) {
dest, _ := proxy.TargetFromContext(ctx) dest, _ := proxy.TargetFromContext(ctx)
writer := &Writer{ writer := &Writer{
dest: dest, dest: dest,
id: s.id, id: s.ID,
writer: output, writer: output,
} }
defer writer.Close() defer writer.Close()
defer s.closeUplink() defer s.CloseUplink()
log.Trace(newError("dispatching request to ", dest)) log.Trace(newError("dispatching request to ", dest))
data, _ := s.input.ReadTimeout(time.Millisecond * 500) data, _ := s.input.ReadTimeout(time.Millisecond * 500)
@ -218,22 +157,9 @@ func fetchInput(ctx context.Context, s *session, output buf.Writer) {
} }
} }
func waitForDone(ctx context.Context, s *session) {
<-ctx.Done()
s.closeUplink()
s.closeDownlink()
s.output.Close()
}
func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool { func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool {
m.access.Lock() numSession := m.sessionManager.Size()
defer m.access.Unlock() if numSession >= int(m.concurrency) || numSession >= maxTotal {
if len(m.sessions) >= int(m.concurrency) {
return false
}
if m.count >= maxTotal {
return false return false
} }
@ -243,17 +169,13 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
default: default:
} }
m.count++ s := &Session{
id := m.count
s := &session{
input: outboundRay.OutboundInput(), input: outboundRay.OutboundInput(),
output: outboundRay.OutboundOutput(), output: outboundRay.OutboundOutput(),
parent: m, parent: m.sessionManager,
id: id,
} }
m.sessions[id] = s m.sessionManager.Allocate(s)
go fetchInput(ctx, s, m.inboundRay.InboundInput()) go fetchInput(ctx, s, m.inboundRay.InboundInput())
go waitForDone(ctx, s)
return true return true
} }
@ -305,11 +227,9 @@ func (m *Client) fetchOutput() {
continue continue
} }
m.access.RLock() s, found := m.sessionManager.Get(meta.SessionID)
s, found := m.sessions[meta.SessionID]
m.access.RUnlock()
if found && meta.SessionStatus == SessionStatusEnd { if found && meta.SessionStatus == SessionStatusEnd {
s.closeDownlink() s.CloseDownlink()
s.output.Close() s.output.Close()
} }
if !meta.Option.Has(OptionData) { if !meta.Option.Has(OptionData) {
@ -354,34 +274,27 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (ray.Inboun
ray := ray.NewRay(ctx) ray := ray.NewRay(ctx)
worker := &ServerWorker{ worker := &ServerWorker{
dispatcher: s.dispatcher, dispatcher: s.dispatcher,
outboundRay: ray, outboundRay: ray,
sessions: make(map[uint16]*session), sessionManager: NewSessionManager(),
} }
go worker.run(ctx) go worker.run(ctx)
return ray, nil return ray, nil
} }
type ServerWorker struct { type ServerWorker struct {
dispatcher dispatcher.Interface dispatcher dispatcher.Interface
outboundRay ray.OutboundRay outboundRay ray.OutboundRay
sessions map[uint16]*session sessionManager *SessionManager
access sync.RWMutex
} }
func (w *ServerWorker) remove(id uint16) { func handle(ctx context.Context, s *Session, output buf.Writer) {
w.access.Lock() writer := NewResponseWriter(s.ID, output)
delete(w.sessions, id)
w.access.Unlock()
}
func handle(ctx context.Context, s *session, output buf.Writer) {
writer := NewResponseWriter(s.id, output)
if err := buf.PipeUntilEOF(signal.BackgroundTimer(), s.input, writer); err != nil { if err := buf.PipeUntilEOF(signal.BackgroundTimer(), s.input, writer); err != nil {
log.Trace(newError("session ", s.id, " ends: ").Base(err)) log.Trace(newError("session ", s.ID, " ends: ").Base(err))
} }
writer.Close() writer.Close()
s.closeDownlink() s.CloseDownlink()
} }
func (w *ServerWorker) run(ctx context.Context) { func (w *ServerWorker) run(ctx context.Context) {
@ -410,12 +323,9 @@ func (w *ServerWorker) run(ctx context.Context) {
continue continue
} }
w.access.RLock() s, found := w.sessionManager.Get(meta.SessionID)
s, found := w.sessions[meta.SessionID]
w.access.RUnlock()
if found && meta.SessionStatus == SessionStatusEnd { if found && meta.SessionStatus == SessionStatusEnd {
s.closeUplink() s.CloseUplink()
s.output.Close() s.output.Close()
} }
@ -426,15 +336,13 @@ func (w *ServerWorker) run(ctx context.Context) {
log.Trace(newError("failed to dispatch request.").Base(err)) log.Trace(newError("failed to dispatch request.").Base(err))
continue continue
} }
s = &session{ s = &Session{
input: inboundRay.InboundOutput(), input: inboundRay.InboundOutput(),
output: inboundRay.InboundInput(), output: inboundRay.InboundInput(),
parent: w, parent: w.sessionManager,
id: meta.SessionID, ID: meta.SessionID,
} }
w.access.Lock() w.sessionManager.Add(s)
w.sessions[meta.SessionID] = s
w.access.Unlock()
go handle(ctx, s, w.outboundRay.OutboundOutput()) go handle(ctx, s, w.outboundRay.OutboundOutput())
} }

View File

@ -0,0 +1,99 @@
package mux
import (
"sync"
"v2ray.com/core/transport/ray"
)
type SessionManager struct {
sync.RWMutex
count uint16
sessions map[uint16]*Session
}
func NewSessionManager() *SessionManager {
return &SessionManager{
count: 0,
sessions: make(map[uint16]*Session, 32),
}
}
func (m *SessionManager) Size() int {
m.RLock()
defer m.RUnlock()
return len(m.sessions)
}
func (m *SessionManager) Allocate(s *Session) {
m.Lock()
defer m.Unlock()
m.count++
s.ID = m.count
m.sessions[s.ID] = s
}
func (m *SessionManager) Add(s *Session) {
m.Lock()
defer m.Unlock()
m.sessions[s.ID] = s
}
func (m *SessionManager) Remove(id uint16) {
m.Lock()
defer m.Unlock()
delete(m.sessions, id)
}
func (m *SessionManager) Get(id uint16) (*Session, bool) {
m.RLock()
defer m.RUnlock()
s, found := m.sessions[id]
return s, found
}
func (m *SessionManager) Close() {
m.RLock()
defer m.RUnlock()
for _, s := range m.sessions {
s.output.CloseError()
}
}
type Session struct {
sync.Mutex
input ray.InputStream
output ray.OutputStream
parent *SessionManager
ID uint16
uplinkClosed bool
downlinkClosed bool
}
func (s *Session) CloseUplink() {
var allDone bool
s.Lock()
s.uplinkClosed = true
allDone = s.uplinkClosed && s.downlinkClosed
s.Unlock()
if allDone {
s.parent.Remove(s.ID)
}
}
func (s *Session) CloseDownlink() {
var allDone bool
s.Lock()
s.downlinkClosed = true
allDone = s.uplinkClosed && s.downlinkClosed
s.Unlock()
if allDone {
s.parent.Remove(s.ID)
}
}

View File

@ -0,0 +1,28 @@
package mux_test
import (
"testing"
. "v2ray.com/core/app/proxyman/mux"
"v2ray.com/core/testing/assert"
)
func TestSessionManagerAdd(t *testing.T) {
assert := assert.On(t)
m := NewSessionManager()
s := &Session{}
m.Allocate(s)
assert.Uint16(s.ID).Equals(1)
s = &Session{}
m.Allocate(s)
assert.Uint16(s.ID).Equals(2)
s = &Session{
ID: 4,
}
m.Add(s)
assert.Uint16(s.ID).Equals(4)
}

View File

@ -0,0 +1,3 @@
package mux
type statusHandler func(meta *FrameMetadata) error