support mtproto conn type 0xee. fixes #1297

pull/1331/head^2 v3.47
Darien Raymond 6 years ago
parent d83959569d
commit 2e94561584
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169

@ -1,6 +1,7 @@
package mtproto package mtproto
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"io" "io"
@ -13,6 +14,35 @@ const (
HeaderSize = 64 HeaderSize = 64
) )
type SessionContext struct {
ConnectionType [4]byte
DataCenterID uint16
}
func DefaultSessionContext() SessionContext {
return SessionContext{
ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
DataCenterID: 0,
}
}
type contextKey int32
const (
sessionContextKey contextKey = iota
)
func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
return context.WithValue(ctx, sessionContextKey, c)
}
func SessionContextFromContext(ctx context.Context) SessionContext {
if c := ctx.Value(sessionContextKey); c != nil {
return c.(SessionContext)
}
return DefaultSessionContext()
}
type Authentication struct { type Authentication struct {
Header [HeaderSize]byte Header [HeaderSize]byte
DecodingKey [32]byte DecodingKey [32]byte
@ -29,12 +59,18 @@ func (a *Authentication) DataCenterID() uint16 {
return uint16(x) - 1 return uint16(x) - 1
} }
func (a *Authentication) ConnectionType() [4]byte {
var x [4]byte
copy(x[:], a.Header[56:60])
return x
}
func (a *Authentication) ApplySecret(b []byte) { func (a *Authentication) ApplySecret(b []byte) {
a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...)) a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...)) a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
} }
func generateRandomBytes(random []byte) { func generateRandomBytes(random []byte, connType [4]byte) {
for { for {
common.Must2(rand.Read(random)) common.Must2(rand.Read(random))
@ -51,19 +87,16 @@ func generateRandomBytes(random []byte) {
continue continue
} }
random[56] = 0xef copy(random[56:60], connType[:])
random[57] = 0xef
random[58] = 0xef
random[59] = 0xef
return return
} }
} }
func NewAuthentication() *Authentication { func NewAuthentication(sc SessionContext) *Authentication {
auth := getAuthenticationObject() auth := getAuthenticationObject()
random := auth.Header[:] random := auth.Header[:]
generateRandomBytes(random) generateRandomBytes(random, sc.ConnectionType)
copy(auth.EncodingKey[:], random[8:]) copy(auth.EncodingKey[:], random[8:])
copy(auth.EncodingNonce[:], random[8+32:]) copy(auth.EncodingNonce[:], random[8+32:])
keyivInverse := Inverse(random[8 : 8+32+16]) keyivInverse := Inverse(random[8 : 8+32+16])

@ -32,7 +32,7 @@ func TestInverse(t *testing.T) {
func TestAuthenticationReadWrite(t *testing.T) { func TestAuthenticationReadWrite(t *testing.T) {
assert := With(t) assert := With(t)
a := NewAuthentication() a := NewAuthentication(DefaultSessionContext())
b := bytes.NewReader(a.Header[:]) b := bytes.NewReader(a.Header[:])
a2, err := ReadAuthentication(b) a2, err := ReadAuthentication(b)
assert(err, IsNil) assert(err, IsNil)

@ -36,7 +36,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
} }
defer conn.Close() // nolint: errcheck defer conn.Close() // nolint: errcheck
auth := NewAuthentication() sc := SessionContextFromContext(ctx)
auth := NewAuthentication(sc)
defer putAuthenticationObject(auth) defer putAuthenticationObject(auth)
request := func() error { request := func() error {

@ -64,6 +64,16 @@ func (s *Server) Network() net.NetworkList {
} }
} }
func isValidConnectionType(c [4]byte) bool {
if compare.BytesAll(c[:], 0xef) {
return true
}
if compare.BytesAll(c[:], 0xee) {
return true
}
return false
}
func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error { func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error {
sPolicy := s.policy.ForLevel(s.user.Level) sPolicy := s.policy.ForLevel(s.user.Level)
@ -85,8 +95,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:]) decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:])
decryptor.XORKeyStream(auth.Header[:], auth.Header[:]) decryptor.XORKeyStream(auth.Header[:], auth.Header[:])
if !compare.BytesAll(auth.Header[56:60], 0xef) { ct := auth.ConnectionType()
return newError("invalid connection type: ", auth.Header[56:60]) if !isValidConnectionType(ct) {
return newError("invalid connection type: ", ct)
} }
dcID := auth.DataCenterID() dcID := auth.DataCenterID()
@ -104,6 +115,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle) timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle)
ctx = core.ContextWithBufferPolicy(ctx, sPolicy.Buffer) ctx = core.ContextWithBufferPolicy(ctx, sPolicy.Buffer)
sc := SessionContext{
ConnectionType: ct,
DataCenterID: dcID,
}
ctx = ContextWithSessionContext(ctx, sc)
link, err := dispatcher.Dispatch(ctx, dest) link, err := dispatcher.Dispatch(ctx, dest)
if err != nil { if err != nil {
return newError("failed to dispatch request to: ", dest).Base(err) return newError("failed to dispatch request to: ", dest).Base(err)

Loading…
Cancel
Save