Move userset to protocol

pull/97/head
v2ray 9 years ago
parent 2d82bb8d4d
commit 791ac307a2

@ -0,0 +1,123 @@
package protocol
import (
"hash"
"sync"
"time"
)
const (
updateIntervalSec = 10
cacheDurationSec = 120
)
type IDHash func(key []byte) hash.Hash
type idEntry struct {
id *ID
userIdx int
lastSec Timestamp
lastSecRemoval Timestamp
}
type UserValidator interface {
Add(user *User) error
Get(timeHash []byte) (*User, Timestamp, bool)
}
type TimedUserValidator struct {
validUsers []*User
userHash map[[16]byte]*indexTimePair
ids []*idEntry
access sync.RWMutex
hasher IDHash
}
type indexTimePair struct {
index int
timeSec Timestamp
}
func NewTimedUserValidator(hasher IDHash) UserValidator {
tus := &TimedUserValidator{
validUsers: make([]*User, 0, 16),
userHash: make(map[[16]byte]*indexTimePair, 512),
access: sync.RWMutex{},
ids: make([]*idEntry, 0, 512),
hasher: hasher,
}
go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
return tus
}
func (this *TimedUserValidator) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
var hashValue [16]byte
var hashValueRemoval [16]byte
idHash := this.hasher(entry.id.Bytes())
for entry.lastSec <= nowSec {
idHash.Write(entry.lastSec.Bytes())
idHash.Sum(hashValue[:0])
idHash.Reset()
idHash.Write(entry.lastSecRemoval.Bytes())
idHash.Sum(hashValueRemoval[:0])
idHash.Reset()
this.access.Lock()
this.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
delete(this.userHash, hashValueRemoval)
this.access.Unlock()
entry.lastSec++
entry.lastSecRemoval++
}
}
func (this *TimedUserValidator) updateUserHash(tick <-chan time.Time) {
for now := range tick {
nowSec := Timestamp(now.Unix() + cacheDurationSec)
for _, entry := range this.ids {
this.generateNewHashes(nowSec, entry.userIdx, entry)
}
}
}
func (this *TimedUserValidator) Add(user *User) error {
idx := len(this.validUsers)
this.validUsers = append(this.validUsers, user)
nowSec := time.Now().Unix()
entry := &idEntry{
id: user.ID,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
this.ids = append(this.ids, entry)
for _, alterid := range user.AlterIDs {
entry := &idEntry{
id: alterid,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
this.ids = append(this.ids, entry)
}
return nil
}
func (this *TimedUserValidator) Get(userHash []byte) (*User, Timestamp, bool) {
defer this.access.RUnlock()
this.access.RLock()
var fixedSizeHash [16]byte
copy(fixedSizeHash[:], userHash)
pair, found := this.userHash[fixedSizeHash]
if found {
return this.validUsers[pair.index], pair.timeSec, true
}
return nil, 0, false
}

@ -66,7 +66,7 @@ type VMessInboundHandler struct {
sync.Mutex
packetDispatcher dispatcher.PacketDispatcher
inboundHandlerManager proxyman.InboundHandlerManager
clients protocol.UserSet
clients proto.UserValidator
usersByEmail *userByEmail
accepting bool
listener *hub.TCPHub
@ -91,7 +91,7 @@ func (this *VMessInboundHandler) Close() {
func (this *VMessInboundHandler) GetUser(email string) *proto.User {
user, existing := this.usersByEmail.Get(email)
if !existing {
this.clients.AddUser(user)
this.clients.Add(user)
}
return user
}
@ -211,9 +211,9 @@ func init() {
}
config := rawConfig.(*Config)
allowedClients := protocol.NewTimedUserSet()
allowedClients := proto.NewTimedUserValidator(protocol.IDHash)
for _, user := range config.AllowedUsers {
allowedClients.AddUser(user)
allowedClients.Add(user)
}
handler := &VMessInboundHandler{

@ -14,6 +14,7 @@ import (
v2io "github.com/v2ray/v2ray-core/common/io"
"github.com/v2ray/v2ray-core/common/log"
v2net "github.com/v2ray/v2ray-core/common/net"
proto "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/proxy"
"github.com/v2ray/v2ray-core/proxy/internal"
vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io"
@ -106,7 +107,7 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol
buffer := alloc.NewBuffer().Clear()
defer buffer.Release()
buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(protocol.Timestamp(time.Now().Unix()), 30), buffer)
buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), buffer)
if err != nil {
log.Error("VMessOut: Failed to serialize VMess request: ", err)
return

@ -2,25 +2,27 @@ package protocol
import (
"math/rand"
"github.com/v2ray/v2ray-core/common/protocol"
)
type RandomTimestampGenerator interface {
Next() Timestamp
Next() protocol.Timestamp
}
type RealRandomTimestampGenerator struct {
base Timestamp
base protocol.Timestamp
delta int
}
func NewRandomTimestampGenerator(base Timestamp, delta int) RandomTimestampGenerator {
func NewRandomTimestampGenerator(base protocol.Timestamp, delta int) RandomTimestampGenerator {
return &RealRandomTimestampGenerator{
base: base,
delta: delta,
}
}
func (this *RealRandomTimestampGenerator) Next() Timestamp {
func (this *RealRandomTimestampGenerator) Next() protocol.Timestamp {
rangeInDelta := rand.Intn(this.delta*2) - this.delta
return this.base + Timestamp(rangeInDelta)
return this.base + protocol.Timestamp(rangeInDelta)
}

@ -4,6 +4,7 @@ import (
"testing"
"time"
"github.com/v2ray/v2ray-core/common/protocol"
. "github.com/v2ray/v2ray-core/proxy/vmess/protocol"
v2testing "github.com/v2ray/v2ray-core/testing"
"github.com/v2ray/v2ray-core/testing/assert"
@ -14,7 +15,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) {
base := time.Now().Unix()
delta := 100
generator := NewRandomTimestampGenerator(Timestamp(base), delta)
generator := NewRandomTimestampGenerator(protocol.Timestamp(base), delta)
for i := 0; i < 100; i++ {
v := int64(generator.Next())

@ -1,22 +1,21 @@
package mocks
import (
proto "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
"github.com/v2ray/v2ray-core/common/protocol"
)
type MockUserSet struct {
Users []*proto.User
Users []*protocol.User
UserHashes map[string]int
Timestamps map[string]protocol.Timestamp
}
func (us *MockUserSet) AddUser(user *proto.User) error {
func (us *MockUserSet) Add(user *protocol.User) error {
us.Users = append(us.Users, user)
return nil
}
func (us *MockUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) {
func (us *MockUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
idx, found := us.UserHashes[string(userhash)]
if found {
return us.Users[idx], us.Timestamps[string(userhash)], true

@ -1,21 +1,20 @@
package mocks
import (
proto "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/common/uuid"
"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
)
type StaticUserSet struct {
}
func (us *StaticUserSet) AddUser(user *proto.User) error {
func (us *StaticUserSet) Add(user *protocol.User) error {
return nil
}
func (us *StaticUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) {
func (us *StaticUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
id, _ := uuid.ParseString("703e9102-eb57-499c-8b59-faf4f371bb21")
return &proto.User{
ID: proto.NewID(id),
return &protocol.User{
ID: protocol.NewID(id),
}, 0, true
}

@ -1,137 +0,0 @@
package protocol
import (
"sync"
"time"
proto "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/common/serial"
)
const (
updateIntervalSec = 10
cacheDurationSec = 120
)
type Timestamp int64
func (this Timestamp) Bytes() []byte {
return serial.Int64Literal(this).Bytes()
}
func (this Timestamp) HashBytes() []byte {
once := this.Bytes()
bytes := make([]byte, 0, 32)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
return bytes
}
type idEntry struct {
id *proto.ID
userIdx int
lastSec Timestamp
lastSecRemoval Timestamp
}
type UserSet interface {
AddUser(user *proto.User) error
GetUser(timeHash []byte) (*proto.User, Timestamp, bool)
}
type TimedUserSet struct {
validUsers []*proto.User
userHash map[[16]byte]*indexTimePair
ids []*idEntry
access sync.RWMutex
}
type indexTimePair struct {
index int
timeSec Timestamp
}
func NewTimedUserSet() UserSet {
tus := &TimedUserSet{
validUsers: make([]*proto.User, 0, 16),
userHash: make(map[[16]byte]*indexTimePair, 512),
access: sync.RWMutex{},
ids: make([]*idEntry, 0, 512),
}
go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
return tus
}
func (us *TimedUserSet) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
var hashValue [16]byte
var hashValueRemoval [16]byte
idHash := IDHash(entry.id.Bytes())
for entry.lastSec <= nowSec {
idHash.Write(entry.lastSec.Bytes())
idHash.Sum(hashValue[:0])
idHash.Reset()
idHash.Write(entry.lastSecRemoval.Bytes())
idHash.Sum(hashValueRemoval[:0])
idHash.Reset()
us.access.Lock()
us.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
delete(us.userHash, hashValueRemoval)
us.access.Unlock()
entry.lastSec++
entry.lastSecRemoval++
}
}
func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) {
for now := range tick {
nowSec := Timestamp(now.Unix() + cacheDurationSec)
for _, entry := range us.ids {
us.generateNewHashes(nowSec, entry.userIdx, entry)
}
}
}
func (us *TimedUserSet) AddUser(user *proto.User) error {
idx := len(us.validUsers)
us.validUsers = append(us.validUsers, user)
nowSec := time.Now().Unix()
entry := &idEntry{
id: user.ID,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
us.ids = append(us.ids, entry)
for _, alterid := range user.AlterIDs {
entry := &idEntry{
id: alterid,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
us.ids = append(us.ids, entry)
}
return nil
}
func (us *TimedUserSet) GetUser(userHash []byte) (*proto.User, Timestamp, bool) {
defer us.access.RUnlock()
us.access.RLock()
var fixedSizeHash [16]byte
copy(fixedSizeHash[:], userHash)
pair, found := us.userHash[fixedSizeHash]
if found {
return us.validUsers[pair.index], pair.timeSec, true
}
return nil, 0, false
}

@ -31,6 +31,16 @@ const (
blockSize = 16
)
func hashTimestamp(t proto.Timestamp) []byte {
once := t.Bytes()
bytes := make([]byte, 0, 32)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
return bytes
}
// VMessRequest implements the request message of VMess protocol. It only contains the header of a
// request message. The data part will be handled by connection handler directly, in favor of data
// streaming.
@ -61,11 +71,11 @@ func (this *VMessRequest) IsChunkStream() bool {
// VMessRequestReader is a parser to read VMessRequest from a byte stream.
type VMessRequestReader struct {
vUserSet UserSet
vUserSet proto.UserValidator
}
// NewVMessRequestReader creates a new VMessRequestReader with a given UserSet
func NewVMessRequestReader(vUserSet UserSet) *VMessRequestReader {
func NewVMessRequestReader(vUserSet proto.UserValidator) *VMessRequestReader {
return &VMessRequestReader{
vUserSet: vUserSet,
}
@ -82,13 +92,13 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
return nil, err
}
userObj, timeSec, valid := this.vUserSet.GetUser(buffer.Value[:nBytes])
userObj, timeSec, valid := this.vUserSet.Get(buffer.Value[:nBytes])
if !valid {
return nil, proxy.ErrorInvalidAuthentication
}
timestampHash := TimestampHash()
timestampHash.Write(timeSec.HashBytes())
timestampHash.Write(hashTimestamp(timeSec))
iv := timestampHash.Sum(nil)
aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID.CmdKey(), iv)
if err != nil {
@ -223,7 +233,7 @@ func (this *VMessRequest) ToBytes(timestampGenerator RandomTimestampGenerator, b
encryptionEnd += 4
timestampHash := md5.New()
timestampHash.Write(timestamp.HashBytes())
timestampHash.Write(hashTimestamp(timestamp))
iv := timestampHash.Sum(nil)
aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID.CmdKey(), iv)
if err != nil {

@ -17,10 +17,10 @@ import (
)
type FakeTimestampGenerator struct {
timestamp Timestamp
timestamp proto.Timestamp
}
func (this *FakeTimestampGenerator) Next() Timestamp {
func (this *FakeTimestampGenerator) Next() proto.Timestamp {
return this.timestamp
}
@ -36,8 +36,8 @@ func TestVMessSerialization(t *testing.T) {
ID: userId,
}
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)}
userSet.AddUser(testUser)
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
userSet.Add(testUser)
request := new(VMessRequest)
request.Version = byte(0x01)
@ -54,7 +54,7 @@ func TestVMessSerialization(t *testing.T) {
request.Address = v2net.DomainAddress("v2ray.com")
request.Port = v2net.Port(80)
mockTime := Timestamp(1823730)
mockTime := proto.Timestamp(1823730)
buffer, err := request.ToBytes(&FakeTimestampGenerator{timestamp: mockTime}, nil)
if err != nil {
@ -92,12 +92,12 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
assert.Error(err).IsNil()
userId := proto.NewID(id)
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)}
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
testUser := &proto.User{
ID: userId,
}
userSet.AddUser(testUser)
userSet.Add(testUser)
request := new(VMessRequest)
request.Version = byte(0x01)
@ -114,6 +114,6 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
request.Port = v2net.Port(80)
for i := 0; i < b.N; i++ {
request.ToBytes(NewRandomTimestampGenerator(Timestamp(time.Now().Unix()), 30), nil)
request.ToBytes(NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), nil)
}
}

Loading…
Cancel
Save