per inbound session history

pull/432/head
Darien Raymond 2017-02-12 16:53:23 +01:00
parent 10d26f2d7f
commit ec95caa946
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
3 changed files with 28 additions and 21 deletions

View File

@ -45,10 +45,12 @@ func TestRequestSerialization(t *testing.T) {
buffer2.Append(buffer.Bytes()) buffer2.Append(buffer.Bytes())
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
sessionHistory := NewSessionHistory(ctx)
userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash) userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash)
userValidator.Add(user) userValidator.Add(user)
server := NewServerSession(userValidator) server := NewServerSession(userValidator, sessionHistory)
actualRequest, err := server.DecodeRequestHeader(buffer) actualRequest, err := server.DecodeRequestHeader(buffer)
assert.Error(err).IsNil() assert.Error(err).IsNil()

View File

@ -1,6 +1,7 @@
package encoding package encoding
import ( import (
"context"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/md5" "crypto/md5"
@ -25,26 +26,26 @@ type sessionId struct {
nonce [16]byte nonce [16]byte
} }
type sessionHistory struct { type SessionHistory struct {
sync.RWMutex sync.RWMutex
cache map[sessionId]time.Time cache map[sessionId]time.Time
} }
func newSessionHistory() *sessionHistory { func NewSessionHistory(ctx context.Context) *SessionHistory {
h := &sessionHistory{ h := &SessionHistory{
cache: make(map[sessionId]time.Time, 128), cache: make(map[sessionId]time.Time, 128),
} }
go h.run() go h.run(ctx)
return h return h
} }
func (h *sessionHistory) Add(session sessionId) { func (h *SessionHistory) add(session sessionId) {
h.Lock() h.Lock()
h.cache[session] = time.Now().Add(time.Minute * 3) h.cache[session] = time.Now().Add(time.Minute * 3)
h.Unlock() h.Unlock()
} }
func (h *sessionHistory) Has(session sessionId) bool { func (h *SessionHistory) has(session sessionId) bool {
h.RLock() h.RLock()
defer h.RUnlock() defer h.RUnlock()
@ -54,9 +55,13 @@ func (h *sessionHistory) Has(session sessionId) bool {
return false return false
} }
func (h *sessionHistory) run() { func (h *SessionHistory) run(ctx context.Context) {
for { for {
time.Sleep(time.Second * 30) select {
case <-ctx.Done():
return
case <-time.After(time.Second * 30):
}
session2Remove := make([]sessionId, 0, 16) session2Remove := make([]sessionId, 0, 16)
now := time.Now() now := time.Now()
h.Lock() h.Lock()
@ -72,12 +77,9 @@ func (h *sessionHistory) run() {
} }
} }
var (
globalSessionHistory = newSessionHistory()
)
type ServerSession struct { type ServerSession struct {
userValidator protocol.UserValidator userValidator protocol.UserValidator
sessionHistory *SessionHistory
requestBodyKey []byte requestBodyKey []byte
requestBodyIV []byte requestBodyIV []byte
responseBodyKey []byte responseBodyKey []byte
@ -88,9 +90,10 @@ type ServerSession struct {
// NewServerSession creates a new ServerSession, using the given UserValidator. // NewServerSession creates a new ServerSession, using the given UserValidator.
// The ServerSession instance doesn't take ownership of the validator. // The ServerSession instance doesn't take ownership of the validator.
func NewServerSession(validator protocol.UserValidator) *ServerSession { func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionHistory) *ServerSession {
return &ServerSession{ return &ServerSession{
userValidator: validator, userValidator: validator,
sessionHistory: sessionHistory,
} }
} }
@ -140,10 +143,10 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
copy(sid.user[:], vmessAccount.ID.Bytes()) copy(sid.user[:], vmessAccount.ID.Bytes())
copy(sid.key[:], v.requestBodyKey) copy(sid.key[:], v.requestBodyKey)
copy(sid.nonce[:], v.requestBodyIV) copy(sid.nonce[:], v.requestBodyIV)
if globalSessionHistory.Has(sid) { if v.sessionHistory.has(sid) {
return nil, errors.New("VMess|Server: Duplicated session id. Possibly under reply attack.") return nil, errors.New("VMess|Server: Duplicated session id. Possibly under reply attack.")
} }
globalSessionHistory.Add(sid) v.sessionHistory.add(sid)
v.responseHeader = buffer[33] // 1 byte v.responseHeader = buffer[33] // 1 byte
request.Option = protocol.RequestOption(buffer[34]) // 1 byte request.Option = protocol.RequestOption(buffer[34]) // 1 byte

View File

@ -78,6 +78,7 @@ type VMessInboundHandler struct {
clients protocol.UserValidator clients protocol.UserValidator
usersByEmail *userByEmail usersByEmail *userByEmail
detours *DetourConfig detours *DetourConfig
sessionHistory *encoding.SessionHistory
} }
func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) { func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) {
@ -92,9 +93,10 @@ func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) {
} }
handler := &VMessInboundHandler{ handler := &VMessInboundHandler{
clients: allowedClients, clients: allowedClients,
detours: config.Detour, detours: config.Detour,
usersByEmail: NewUserByEmail(config.User, config.GetDefaultValue()), usersByEmail: NewUserByEmail(config.User, config.GetDefaultValue()),
sessionHistory: encoding.NewSessionHistory(ctx),
} }
space.OnInitialize(func() error { space.OnInitialize(func() error {
@ -171,7 +173,7 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
connection.SetReadDeadline(time.Now().Add(time.Second * 8)) connection.SetReadDeadline(time.Now().Add(time.Second * 8))
reader := bufio.NewReader(connection) reader := bufio.NewReader(connection)
session := encoding.NewServerSession(v.clients) session := encoding.NewServerSession(v.clients, v.sessionHistory)
request, err := session.DecodeRequestHeader(reader) request, err := session.DecodeRequestHeader(reader)
if err != nil { if err != nil {