mirror of https://github.com/v2ray/v2ray-core
Move userset to protocol
parent
2d82bb8d4d
commit
791ac307a2
|
@ -0,0 +1,123 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"hash"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
updateIntervalSec = 10
|
||||
cacheDurationSec = 120
|
||||
)
|
||||
|
||||
type IDHash func(key []byte) hash.Hash
|
||||
|
||||
type idEntry struct {
|
||||
id *ID
|
||||
userIdx int
|
||||
lastSec Timestamp
|
||||
lastSecRemoval Timestamp
|
||||
}
|
||||
|
||||
type UserValidator interface {
|
||||
Add(user *User) error
|
||||
Get(timeHash []byte) (*User, Timestamp, bool)
|
||||
}
|
||||
|
||||
type TimedUserValidator struct {
|
||||
validUsers []*User
|
||||
userHash map[[16]byte]*indexTimePair
|
||||
ids []*idEntry
|
||||
access sync.RWMutex
|
||||
hasher IDHash
|
||||
}
|
||||
|
||||
type indexTimePair struct {
|
||||
index int
|
||||
timeSec Timestamp
|
||||
}
|
||||
|
||||
func NewTimedUserValidator(hasher IDHash) UserValidator {
|
||||
tus := &TimedUserValidator{
|
||||
validUsers: make([]*User, 0, 16),
|
||||
userHash: make(map[[16]byte]*indexTimePair, 512),
|
||||
access: sync.RWMutex{},
|
||||
ids: make([]*idEntry, 0, 512),
|
||||
hasher: hasher,
|
||||
}
|
||||
go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
|
||||
return tus
|
||||
}
|
||||
|
||||
func (this *TimedUserValidator) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
|
||||
var hashValue [16]byte
|
||||
var hashValueRemoval [16]byte
|
||||
idHash := this.hasher(entry.id.Bytes())
|
||||
for entry.lastSec <= nowSec {
|
||||
idHash.Write(entry.lastSec.Bytes())
|
||||
idHash.Sum(hashValue[:0])
|
||||
idHash.Reset()
|
||||
|
||||
idHash.Write(entry.lastSecRemoval.Bytes())
|
||||
idHash.Sum(hashValueRemoval[:0])
|
||||
idHash.Reset()
|
||||
|
||||
this.access.Lock()
|
||||
this.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
|
||||
delete(this.userHash, hashValueRemoval)
|
||||
this.access.Unlock()
|
||||
|
||||
entry.lastSec++
|
||||
entry.lastSecRemoval++
|
||||
}
|
||||
}
|
||||
|
||||
func (this *TimedUserValidator) updateUserHash(tick <-chan time.Time) {
|
||||
for now := range tick {
|
||||
nowSec := Timestamp(now.Unix() + cacheDurationSec)
|
||||
for _, entry := range this.ids {
|
||||
this.generateNewHashes(nowSec, entry.userIdx, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *TimedUserValidator) Add(user *User) error {
|
||||
idx := len(this.validUsers)
|
||||
this.validUsers = append(this.validUsers, user)
|
||||
|
||||
nowSec := time.Now().Unix()
|
||||
|
||||
entry := &idEntry{
|
||||
id: user.ID,
|
||||
userIdx: idx,
|
||||
lastSec: Timestamp(nowSec - cacheDurationSec),
|
||||
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
|
||||
}
|
||||
this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
|
||||
this.ids = append(this.ids, entry)
|
||||
for _, alterid := range user.AlterIDs {
|
||||
entry := &idEntry{
|
||||
id: alterid,
|
||||
userIdx: idx,
|
||||
lastSec: Timestamp(nowSec - cacheDurationSec),
|
||||
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
|
||||
}
|
||||
this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
|
||||
this.ids = append(this.ids, entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TimedUserValidator) Get(userHash []byte) (*User, Timestamp, bool) {
|
||||
defer this.access.RUnlock()
|
||||
this.access.RLock()
|
||||
var fixedSizeHash [16]byte
|
||||
copy(fixedSizeHash[:], userHash)
|
||||
pair, found := this.userHash[fixedSizeHash]
|
||||
if found {
|
||||
return this.validUsers[pair.index], pair.timeSec, true
|
||||
}
|
||||
return nil, 0, false
|
||||
}
|
|
@ -66,7 +66,7 @@ type VMessInboundHandler struct {
|
|||
sync.Mutex
|
||||
packetDispatcher dispatcher.PacketDispatcher
|
||||
inboundHandlerManager proxyman.InboundHandlerManager
|
||||
clients protocol.UserSet
|
||||
clients proto.UserValidator
|
||||
usersByEmail *userByEmail
|
||||
accepting bool
|
||||
listener *hub.TCPHub
|
||||
|
@ -91,7 +91,7 @@ func (this *VMessInboundHandler) Close() {
|
|||
func (this *VMessInboundHandler) GetUser(email string) *proto.User {
|
||||
user, existing := this.usersByEmail.Get(email)
|
||||
if !existing {
|
||||
this.clients.AddUser(user)
|
||||
this.clients.Add(user)
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
@ -211,9 +211,9 @@ func init() {
|
|||
}
|
||||
config := rawConfig.(*Config)
|
||||
|
||||
allowedClients := protocol.NewTimedUserSet()
|
||||
allowedClients := proto.NewTimedUserValidator(protocol.IDHash)
|
||||
for _, user := range config.AllowedUsers {
|
||||
allowedClients.AddUser(user)
|
||||
allowedClients.Add(user)
|
||||
}
|
||||
|
||||
handler := &VMessInboundHandler{
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
v2io "github.com/v2ray/v2ray-core/common/io"
|
||||
"github.com/v2ray/v2ray-core/common/log"
|
||||
v2net "github.com/v2ray/v2ray-core/common/net"
|
||||
proto "github.com/v2ray/v2ray-core/common/protocol"
|
||||
"github.com/v2ray/v2ray-core/proxy"
|
||||
"github.com/v2ray/v2ray-core/proxy/internal"
|
||||
vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io"
|
||||
|
@ -106,7 +107,7 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol
|
|||
|
||||
buffer := alloc.NewBuffer().Clear()
|
||||
defer buffer.Release()
|
||||
buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(protocol.Timestamp(time.Now().Unix()), 30), buffer)
|
||||
buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), buffer)
|
||||
if err != nil {
|
||||
log.Error("VMessOut: Failed to serialize VMess request: ", err)
|
||||
return
|
||||
|
|
|
@ -2,25 +2,27 @@ package protocol
|
|||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/v2ray/v2ray-core/common/protocol"
|
||||
)
|
||||
|
||||
type RandomTimestampGenerator interface {
|
||||
Next() Timestamp
|
||||
Next() protocol.Timestamp
|
||||
}
|
||||
|
||||
type RealRandomTimestampGenerator struct {
|
||||
base Timestamp
|
||||
base protocol.Timestamp
|
||||
delta int
|
||||
}
|
||||
|
||||
func NewRandomTimestampGenerator(base Timestamp, delta int) RandomTimestampGenerator {
|
||||
func NewRandomTimestampGenerator(base protocol.Timestamp, delta int) RandomTimestampGenerator {
|
||||
return &RealRandomTimestampGenerator{
|
||||
base: base,
|
||||
delta: delta,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *RealRandomTimestampGenerator) Next() Timestamp {
|
||||
func (this *RealRandomTimestampGenerator) Next() protocol.Timestamp {
|
||||
rangeInDelta := rand.Intn(this.delta*2) - this.delta
|
||||
return this.base + Timestamp(rangeInDelta)
|
||||
return this.base + protocol.Timestamp(rangeInDelta)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/v2ray/v2ray-core/common/protocol"
|
||||
. "github.com/v2ray/v2ray-core/proxy/vmess/protocol"
|
||||
v2testing "github.com/v2ray/v2ray-core/testing"
|
||||
"github.com/v2ray/v2ray-core/testing/assert"
|
||||
|
@ -14,7 +15,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) {
|
|||
|
||||
base := time.Now().Unix()
|
||||
delta := 100
|
||||
generator := NewRandomTimestampGenerator(Timestamp(base), delta)
|
||||
generator := NewRandomTimestampGenerator(protocol.Timestamp(base), delta)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
v := int64(generator.Next())
|
||||
|
|
|
@ -1,22 +1,21 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
proto "github.com/v2ray/v2ray-core/common/protocol"
|
||||
"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
|
||||
"github.com/v2ray/v2ray-core/common/protocol"
|
||||
)
|
||||
|
||||
type MockUserSet struct {
|
||||
Users []*proto.User
|
||||
Users []*protocol.User
|
||||
UserHashes map[string]int
|
||||
Timestamps map[string]protocol.Timestamp
|
||||
}
|
||||
|
||||
func (us *MockUserSet) AddUser(user *proto.User) error {
|
||||
func (us *MockUserSet) Add(user *protocol.User) error {
|
||||
us.Users = append(us.Users, user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (us *MockUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) {
|
||||
func (us *MockUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
|
||||
idx, found := us.UserHashes[string(userhash)]
|
||||
if found {
|
||||
return us.Users[idx], us.Timestamps[string(userhash)], true
|
||||
|
|
|
@ -1,21 +1,20 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
proto "github.com/v2ray/v2ray-core/common/protocol"
|
||||
"github.com/v2ray/v2ray-core/common/protocol"
|
||||
"github.com/v2ray/v2ray-core/common/uuid"
|
||||
"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
|
||||
)
|
||||
|
||||
type StaticUserSet struct {
|
||||
}
|
||||
|
||||
func (us *StaticUserSet) AddUser(user *proto.User) error {
|
||||
func (us *StaticUserSet) Add(user *protocol.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (us *StaticUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) {
|
||||
func (us *StaticUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
|
||||
id, _ := uuid.ParseString("703e9102-eb57-499c-8b59-faf4f371bb21")
|
||||
return &proto.User{
|
||||
ID: proto.NewID(id),
|
||||
return &protocol.User{
|
||||
ID: protocol.NewID(id),
|
||||
}, 0, true
|
||||
}
|
||||
|
|
|
@ -1,137 +0,0 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
proto "github.com/v2ray/v2ray-core/common/protocol"
|
||||
"github.com/v2ray/v2ray-core/common/serial"
|
||||
)
|
||||
|
||||
const (
|
||||
updateIntervalSec = 10
|
||||
cacheDurationSec = 120
|
||||
)
|
||||
|
||||
type Timestamp int64
|
||||
|
||||
func (this Timestamp) Bytes() []byte {
|
||||
return serial.Int64Literal(this).Bytes()
|
||||
}
|
||||
|
||||
func (this Timestamp) HashBytes() []byte {
|
||||
once := this.Bytes()
|
||||
bytes := make([]byte, 0, 32)
|
||||
bytes = append(bytes, once...)
|
||||
bytes = append(bytes, once...)
|
||||
bytes = append(bytes, once...)
|
||||
bytes = append(bytes, once...)
|
||||
return bytes
|
||||
}
|
||||
|
||||
type idEntry struct {
|
||||
id *proto.ID
|
||||
userIdx int
|
||||
lastSec Timestamp
|
||||
lastSecRemoval Timestamp
|
||||
}
|
||||
|
||||
type UserSet interface {
|
||||
AddUser(user *proto.User) error
|
||||
GetUser(timeHash []byte) (*proto.User, Timestamp, bool)
|
||||
}
|
||||
|
||||
type TimedUserSet struct {
|
||||
validUsers []*proto.User
|
||||
userHash map[[16]byte]*indexTimePair
|
||||
ids []*idEntry
|
||||
access sync.RWMutex
|
||||
}
|
||||
|
||||
type indexTimePair struct {
|
||||
index int
|
||||
timeSec Timestamp
|
||||
}
|
||||
|
||||
func NewTimedUserSet() UserSet {
|
||||
tus := &TimedUserSet{
|
||||
validUsers: make([]*proto.User, 0, 16),
|
||||
userHash: make(map[[16]byte]*indexTimePair, 512),
|
||||
access: sync.RWMutex{},
|
||||
ids: make([]*idEntry, 0, 512),
|
||||
}
|
||||
go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
|
||||
return tus
|
||||
}
|
||||
|
||||
func (us *TimedUserSet) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
|
||||
var hashValue [16]byte
|
||||
var hashValueRemoval [16]byte
|
||||
idHash := IDHash(entry.id.Bytes())
|
||||
for entry.lastSec <= nowSec {
|
||||
idHash.Write(entry.lastSec.Bytes())
|
||||
idHash.Sum(hashValue[:0])
|
||||
idHash.Reset()
|
||||
|
||||
idHash.Write(entry.lastSecRemoval.Bytes())
|
||||
idHash.Sum(hashValueRemoval[:0])
|
||||
idHash.Reset()
|
||||
|
||||
us.access.Lock()
|
||||
us.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
|
||||
delete(us.userHash, hashValueRemoval)
|
||||
us.access.Unlock()
|
||||
|
||||
entry.lastSec++
|
||||
entry.lastSecRemoval++
|
||||
}
|
||||
}
|
||||
|
||||
func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) {
|
||||
for now := range tick {
|
||||
nowSec := Timestamp(now.Unix() + cacheDurationSec)
|
||||
for _, entry := range us.ids {
|
||||
us.generateNewHashes(nowSec, entry.userIdx, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (us *TimedUserSet) AddUser(user *proto.User) error {
|
||||
idx := len(us.validUsers)
|
||||
us.validUsers = append(us.validUsers, user)
|
||||
|
||||
nowSec := time.Now().Unix()
|
||||
|
||||
entry := &idEntry{
|
||||
id: user.ID,
|
||||
userIdx: idx,
|
||||
lastSec: Timestamp(nowSec - cacheDurationSec),
|
||||
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
|
||||
}
|
||||
us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
|
||||
us.ids = append(us.ids, entry)
|
||||
for _, alterid := range user.AlterIDs {
|
||||
entry := &idEntry{
|
||||
id: alterid,
|
||||
userIdx: idx,
|
||||
lastSec: Timestamp(nowSec - cacheDurationSec),
|
||||
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
|
||||
}
|
||||
us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
|
||||
us.ids = append(us.ids, entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (us *TimedUserSet) GetUser(userHash []byte) (*proto.User, Timestamp, bool) {
|
||||
defer us.access.RUnlock()
|
||||
us.access.RLock()
|
||||
var fixedSizeHash [16]byte
|
||||
copy(fixedSizeHash[:], userHash)
|
||||
pair, found := us.userHash[fixedSizeHash]
|
||||
if found {
|
||||
return us.validUsers[pair.index], pair.timeSec, true
|
||||
}
|
||||
return nil, 0, false
|
||||
}
|
|
@ -31,6 +31,16 @@ const (
|
|||
blockSize = 16
|
||||
)
|
||||
|
||||
func hashTimestamp(t proto.Timestamp) []byte {
|
||||
once := t.Bytes()
|
||||
bytes := make([]byte, 0, 32)
|
||||
bytes = append(bytes, once...)
|
||||
bytes = append(bytes, once...)
|
||||
bytes = append(bytes, once...)
|
||||
bytes = append(bytes, once...)
|
||||
return bytes
|
||||
}
|
||||
|
||||
// VMessRequest implements the request message of VMess protocol. It only contains the header of a
|
||||
// request message. The data part will be handled by connection handler directly, in favor of data
|
||||
// streaming.
|
||||
|
@ -61,11 +71,11 @@ func (this *VMessRequest) IsChunkStream() bool {
|
|||
|
||||
// VMessRequestReader is a parser to read VMessRequest from a byte stream.
|
||||
type VMessRequestReader struct {
|
||||
vUserSet UserSet
|
||||
vUserSet proto.UserValidator
|
||||
}
|
||||
|
||||
// NewVMessRequestReader creates a new VMessRequestReader with a given UserSet
|
||||
func NewVMessRequestReader(vUserSet UserSet) *VMessRequestReader {
|
||||
func NewVMessRequestReader(vUserSet proto.UserValidator) *VMessRequestReader {
|
||||
return &VMessRequestReader{
|
||||
vUserSet: vUserSet,
|
||||
}
|
||||
|
@ -82,13 +92,13 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
userObj, timeSec, valid := this.vUserSet.GetUser(buffer.Value[:nBytes])
|
||||
userObj, timeSec, valid := this.vUserSet.Get(buffer.Value[:nBytes])
|
||||
if !valid {
|
||||
return nil, proxy.ErrorInvalidAuthentication
|
||||
}
|
||||
|
||||
timestampHash := TimestampHash()
|
||||
timestampHash.Write(timeSec.HashBytes())
|
||||
timestampHash.Write(hashTimestamp(timeSec))
|
||||
iv := timestampHash.Sum(nil)
|
||||
aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID.CmdKey(), iv)
|
||||
if err != nil {
|
||||
|
@ -223,7 +233,7 @@ func (this *VMessRequest) ToBytes(timestampGenerator RandomTimestampGenerator, b
|
|||
encryptionEnd += 4
|
||||
|
||||
timestampHash := md5.New()
|
||||
timestampHash.Write(timestamp.HashBytes())
|
||||
timestampHash.Write(hashTimestamp(timestamp))
|
||||
iv := timestampHash.Sum(nil)
|
||||
aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID.CmdKey(), iv)
|
||||
if err != nil {
|
||||
|
|
|
@ -17,10 +17,10 @@ import (
|
|||
)
|
||||
|
||||
type FakeTimestampGenerator struct {
|
||||
timestamp Timestamp
|
||||
timestamp proto.Timestamp
|
||||
}
|
||||
|
||||
func (this *FakeTimestampGenerator) Next() Timestamp {
|
||||
func (this *FakeTimestampGenerator) Next() proto.Timestamp {
|
||||
return this.timestamp
|
||||
}
|
||||
|
||||
|
@ -36,8 +36,8 @@ func TestVMessSerialization(t *testing.T) {
|
|||
ID: userId,
|
||||
}
|
||||
|
||||
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)}
|
||||
userSet.AddUser(testUser)
|
||||
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
|
||||
userSet.Add(testUser)
|
||||
|
||||
request := new(VMessRequest)
|
||||
request.Version = byte(0x01)
|
||||
|
@ -54,7 +54,7 @@ func TestVMessSerialization(t *testing.T) {
|
|||
request.Address = v2net.DomainAddress("v2ray.com")
|
||||
request.Port = v2net.Port(80)
|
||||
|
||||
mockTime := Timestamp(1823730)
|
||||
mockTime := proto.Timestamp(1823730)
|
||||
|
||||
buffer, err := request.ToBytes(&FakeTimestampGenerator{timestamp: mockTime}, nil)
|
||||
if err != nil {
|
||||
|
@ -92,12 +92,12 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
|
|||
assert.Error(err).IsNil()
|
||||
|
||||
userId := proto.NewID(id)
|
||||
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)}
|
||||
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
|
||||
|
||||
testUser := &proto.User{
|
||||
ID: userId,
|
||||
}
|
||||
userSet.AddUser(testUser)
|
||||
userSet.Add(testUser)
|
||||
|
||||
request := new(VMessRequest)
|
||||
request.Version = byte(0x01)
|
||||
|
@ -114,6 +114,6 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
|
|||
request.Port = v2net.Port(80)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
request.ToBytes(NewRandomTimestampGenerator(Timestamp(time.Now().Unix()), 30), nil)
|
||||
request.ToBytes(NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), nil)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue