From ec95caa9468e3c82c0441d2cf99d89e29f916778 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 12 Feb 2017 16:53:23 +0100 Subject: [PATCH] per inbound session history --- proxy/vmess/encoding/encoding_test.go | 4 ++- proxy/vmess/encoding/server.go | 35 +++++++++++++++------------ proxy/vmess/inbound/inbound.go | 10 +++++--- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index ea2e999c..83973e4c 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -45,10 +45,12 @@ func TestRequestSerialization(t *testing.T) { buffer2.Append(buffer.Bytes()) ctx, cancel := context.WithCancel(context.Background()) + sessionHistory := NewSessionHistory(ctx) + userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash) userValidator.Add(user) - server := NewServerSession(userValidator) + server := NewServerSession(userValidator, sessionHistory) actualRequest, err := server.DecodeRequestHeader(buffer) assert.Error(err).IsNil() diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index d968da3c..971fc734 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -1,6 +1,7 @@ package encoding import ( + "context" "crypto/aes" "crypto/cipher" "crypto/md5" @@ -25,26 +26,26 @@ type sessionId struct { nonce [16]byte } -type sessionHistory struct { +type SessionHistory struct { sync.RWMutex cache map[sessionId]time.Time } -func newSessionHistory() *sessionHistory { - h := &sessionHistory{ +func NewSessionHistory(ctx context.Context) *SessionHistory { + h := &SessionHistory{ cache: make(map[sessionId]time.Time, 128), } - go h.run() + go h.run(ctx) return h } -func (h *sessionHistory) Add(session sessionId) { +func (h *SessionHistory) add(session sessionId) { h.Lock() h.cache[session] = time.Now().Add(time.Minute * 3) h.Unlock() } -func (h *sessionHistory) Has(session sessionId) bool { +func (h *SessionHistory) has(session sessionId) bool { h.RLock() defer h.RUnlock() @@ -54,9 +55,13 @@ func (h *sessionHistory) Has(session sessionId) bool { return false } -func (h *sessionHistory) run() { +func (h *SessionHistory) run(ctx context.Context) { for { - time.Sleep(time.Second * 30) + select { + case <-ctx.Done(): + return + case <-time.After(time.Second * 30): + } session2Remove := make([]sessionId, 0, 16) now := time.Now() h.Lock() @@ -72,12 +77,9 @@ func (h *sessionHistory) run() { } } -var ( - globalSessionHistory = newSessionHistory() -) - type ServerSession struct { userValidator protocol.UserValidator + sessionHistory *SessionHistory requestBodyKey []byte requestBodyIV []byte responseBodyKey []byte @@ -88,9 +90,10 @@ type ServerSession struct { // NewServerSession creates a new ServerSession, using the given UserValidator. // The ServerSession instance doesn't take ownership of the validator. -func NewServerSession(validator protocol.UserValidator) *ServerSession { +func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionHistory) *ServerSession { return &ServerSession{ - userValidator: validator, + userValidator: validator, + sessionHistory: sessionHistory, } } @@ -140,10 +143,10 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request copy(sid.user[:], vmessAccount.ID.Bytes()) copy(sid.key[:], v.requestBodyKey) copy(sid.nonce[:], v.requestBodyIV) - if globalSessionHistory.Has(sid) { + if v.sessionHistory.has(sid) { return nil, errors.New("VMess|Server: Duplicated session id. Possibly under reply attack.") } - globalSessionHistory.Add(sid) + v.sessionHistory.add(sid) v.responseHeader = buffer[33] // 1 byte request.Option = protocol.RequestOption(buffer[34]) // 1 byte diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 1d3d202e..ac8337eb 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -78,6 +78,7 @@ type VMessInboundHandler struct { clients protocol.UserValidator usersByEmail *userByEmail detours *DetourConfig + sessionHistory *encoding.SessionHistory } func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) { @@ -92,9 +93,10 @@ func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) { } handler := &VMessInboundHandler{ - clients: allowedClients, - detours: config.Detour, - usersByEmail: NewUserByEmail(config.User, config.GetDefaultValue()), + clients: allowedClients, + detours: config.Detour, + usersByEmail: NewUserByEmail(config.User, config.GetDefaultValue()), + sessionHistory: encoding.NewSessionHistory(ctx), } space.OnInitialize(func() error { @@ -171,7 +173,7 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network, connection.SetReadDeadline(time.Now().Add(time.Second * 8)) reader := bufio.NewReader(connection) - session := encoding.NewServerSession(v.clients) + session := encoding.NewServerSession(v.clients, v.sessionHistory) request, err := session.DecodeRequestHeader(reader) if err != nil {