mirror of https://github.com/v2ray/v2ray-core
				
				
				
			
							parent
							
								
									d83959569d
								
							
						
					
					
						commit
						2e94561584
					
				| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
package mtproto
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"io"
 | 
			
		||||
| 
						 | 
				
			
			@ -13,6 +14,35 @@ const (
 | 
			
		|||
	HeaderSize = 64
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SessionContext struct {
 | 
			
		||||
	ConnectionType [4]byte
 | 
			
		||||
	DataCenterID   uint16
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DefaultSessionContext() SessionContext {
 | 
			
		||||
	return SessionContext{
 | 
			
		||||
		ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
 | 
			
		||||
		DataCenterID:   0,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type contextKey int32
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	sessionContextKey contextKey = iota
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
 | 
			
		||||
	return context.WithValue(ctx, sessionContextKey, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SessionContextFromContext(ctx context.Context) SessionContext {
 | 
			
		||||
	if c := ctx.Value(sessionContextKey); c != nil {
 | 
			
		||||
		return c.(SessionContext)
 | 
			
		||||
	}
 | 
			
		||||
	return DefaultSessionContext()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Authentication struct {
 | 
			
		||||
	Header        [HeaderSize]byte
 | 
			
		||||
	DecodingKey   [32]byte
 | 
			
		||||
| 
						 | 
				
			
			@ -29,12 +59,18 @@ func (a *Authentication) DataCenterID() uint16 {
 | 
			
		|||
	return uint16(x) - 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authentication) ConnectionType() [4]byte {
 | 
			
		||||
	var x [4]byte
 | 
			
		||||
	copy(x[:], a.Header[56:60])
 | 
			
		||||
	return x
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authentication) ApplySecret(b []byte) {
 | 
			
		||||
	a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
 | 
			
		||||
	a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func generateRandomBytes(random []byte) {
 | 
			
		||||
func generateRandomBytes(random []byte, connType [4]byte) {
 | 
			
		||||
	for {
 | 
			
		||||
		common.Must2(rand.Read(random))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -51,19 +87,16 @@ func generateRandomBytes(random []byte) {
 | 
			
		|||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		random[56] = 0xef
 | 
			
		||||
		random[57] = 0xef
 | 
			
		||||
		random[58] = 0xef
 | 
			
		||||
		random[59] = 0xef
 | 
			
		||||
		copy(random[56:60], connType[:])
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAuthentication() *Authentication {
 | 
			
		||||
func NewAuthentication(sc SessionContext) *Authentication {
 | 
			
		||||
	auth := getAuthenticationObject()
 | 
			
		||||
	random := auth.Header[:]
 | 
			
		||||
	generateRandomBytes(random)
 | 
			
		||||
	generateRandomBytes(random, sc.ConnectionType)
 | 
			
		||||
	copy(auth.EncodingKey[:], random[8:])
 | 
			
		||||
	copy(auth.EncodingNonce[:], random[8+32:])
 | 
			
		||||
	keyivInverse := Inverse(random[8 : 8+32+16])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -32,7 +32,7 @@ func TestInverse(t *testing.T) {
 | 
			
		|||
func TestAuthenticationReadWrite(t *testing.T) {
 | 
			
		||||
	assert := With(t)
 | 
			
		||||
 | 
			
		||||
	a := NewAuthentication()
 | 
			
		||||
	a := NewAuthentication(DefaultSessionContext())
 | 
			
		||||
	b := bytes.NewReader(a.Header[:])
 | 
			
		||||
	a2, err := ReadAuthentication(b)
 | 
			
		||||
	assert(err, IsNil)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -36,7 +36,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 | 
			
		|||
	}
 | 
			
		||||
	defer conn.Close() // nolint: errcheck
 | 
			
		||||
 | 
			
		||||
	auth := NewAuthentication()
 | 
			
		||||
	sc := SessionContextFromContext(ctx)
 | 
			
		||||
	auth := NewAuthentication(sc)
 | 
			
		||||
	defer putAuthenticationObject(auth)
 | 
			
		||||
 | 
			
		||||
	request := func() error {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -64,6 +64,16 @@ func (s *Server) Network() net.NetworkList {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidConnectionType(c [4]byte) bool {
 | 
			
		||||
	if compare.BytesAll(c[:], 0xef) {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if compare.BytesAll(c[:], 0xee) {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error {
 | 
			
		||||
	sPolicy := s.policy.ForLevel(s.user.Level)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -85,8 +95,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 | 
			
		|||
	decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:])
 | 
			
		||||
	decryptor.XORKeyStream(auth.Header[:], auth.Header[:])
 | 
			
		||||
 | 
			
		||||
	if !compare.BytesAll(auth.Header[56:60], 0xef) {
 | 
			
		||||
		return newError("invalid connection type: ", auth.Header[56:60])
 | 
			
		||||
	ct := auth.ConnectionType()
 | 
			
		||||
	if !isValidConnectionType(ct) {
 | 
			
		||||
		return newError("invalid connection type: ", ct)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dcID := auth.DataCenterID()
 | 
			
		||||
| 
						 | 
				
			
			@ -104,6 +115,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 | 
			
		|||
	timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle)
 | 
			
		||||
	ctx = core.ContextWithBufferPolicy(ctx, sPolicy.Buffer)
 | 
			
		||||
 | 
			
		||||
	sc := SessionContext{
 | 
			
		||||
		ConnectionType: ct,
 | 
			
		||||
		DataCenterID:   dcID,
 | 
			
		||||
	}
 | 
			
		||||
	ctx = ContextWithSessionContext(ctx, sc)
 | 
			
		||||
 | 
			
		||||
	link, err := dispatcher.Dispatch(ctx, dest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return newError("failed to dispatch request to: ", dest).Base(err)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue