From 57ff7ba923ad4cf12a405e2a3fcc6bf5e947fb20 Mon Sep 17 00:00:00 2001
From: Darien Raymond <admin@v2ray.com>
Date: Fri, 29 Jan 2016 15:43:45 +0000
Subject: [PATCH] complete implementation of shadowsocks ota

---
 common/io/reader.go                |  4 +--
 common/serial/numbers.go           | 10 ++++++
 proxy/shadowsocks/ota.go           | 55 ++++++++++++++++++++++++++++++
 proxy/shadowsocks/protocol.go      | 45 ++++++++++++++++++------
 proxy/shadowsocks/protocol_test.go |  9 +++--
 proxy/shadowsocks/shadowsocks.go   | 35 ++++++++++++++-----
 6 files changed, 134 insertions(+), 24 deletions(-)
 create mode 100644 proxy/shadowsocks/ota.go

diff --git a/common/io/reader.go b/common/io/reader.go
index 33bc7aa0..fd7c073c 100644
--- a/common/io/reader.go
+++ b/common/io/reader.go
@@ -87,9 +87,9 @@ type AuthenticationReader struct {
 	authBeforePayload bool
 }
 
-func NewAuthenticationReader(reader io.Reader, auth crypto.Authenticator, authBeforePayload bool) *AuthenticationReader {
+func NewAuthenticationReader(reader Reader, auth crypto.Authenticator, authBeforePayload bool) *AuthenticationReader {
 	return &AuthenticationReader{
-		reader:            NewChunkReader(reader),
+		reader:            reader,
 		authenticator:     auth,
 		authBeforePayload: authBeforePayload,
 	}
diff --git a/common/serial/numbers.go b/common/serial/numbers.go
index 9de8759c..d186ba60 100644
--- a/common/serial/numbers.go
+++ b/common/serial/numbers.go
@@ -36,6 +36,16 @@ func (this IntLiteral) Value() int {
 	return int(this)
 }
 
+func (this IntLiteral) Bytes() []byte {
+	value := this.Value()
+	return []byte{
+		byte(value >> 24),
+		byte(value >> 16),
+		byte(value >> 8),
+		byte(value),
+	}
+}
+
 type Int64Literal int64
 
 func (this Int64Literal) String() string {
diff --git a/proxy/shadowsocks/ota.go b/proxy/shadowsocks/ota.go
new file mode 100644
index 00000000..a30dd3d0
--- /dev/null
+++ b/proxy/shadowsocks/ota.go
@@ -0,0 +1,55 @@
+package shadowsocks
+
+import (
+	"crypto/hmac"
+	"crypto/sha1"
+
+	"github.com/v2ray/v2ray-core/common/serial"
+)
+
+const (
+	AuthSize = 10
+)
+
+type KeyGenerator func() []byte
+
+type Authenticator struct {
+	key KeyGenerator
+}
+
+func NewAuthenticator(keygen KeyGenerator) *Authenticator {
+	return &Authenticator{
+		key: keygen,
+	}
+}
+
+func (this *Authenticator) AuthSize() int {
+	return AuthSize
+}
+
+func (this *Authenticator) Authenticate(auth []byte, data []byte) []byte {
+	hasher := hmac.New(sha1.New, this.key())
+	hasher.Write(data)
+	res := hasher.Sum(nil)
+	return append(auth, res[:AuthSize]...)
+}
+
+func HeaderKeyGenerator(key []byte, iv []byte) func() []byte {
+	return func() []byte {
+		newKey := make([]byte, 0, len(key)+len(iv))
+		newKey = append(newKey, key...)
+		newKey = append(newKey, iv...)
+		return newKey
+	}
+}
+
+func ChunkKeyGenerator(iv []byte) func() []byte {
+	chunkId := 0
+	return func() []byte {
+		newKey := make([]byte, 0, len(iv)+4)
+		newKey = append(newKey, iv...)
+		newKey = append(newKey, serial.IntLiteral(chunkId).Bytes()...)
+		chunkId++
+		return newKey
+	}
+}
diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go
index 079edf84..bc3af8df 100644
--- a/proxy/shadowsocks/protocol.go
+++ b/proxy/shadowsocks/protocol.go
@@ -6,6 +6,7 @@ import (
 	"github.com/v2ray/v2ray-core/common/alloc"
 	"github.com/v2ray/v2ray-core/common/log"
 	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/common/serial"
 	"github.com/v2ray/v2ray-core/transport"
 )
 
@@ -21,7 +22,7 @@ type Request struct {
 	OTA     bool
 }
 
-func ReadRequest(reader io.Reader) (*Request, error) {
+func ReadRequest(reader io.Reader, auth *Authenticator) (*Request, error) {
 	buffer := alloc.NewSmallBuffer()
 	defer buffer.Release()
 
@@ -30,6 +31,7 @@ func ReadRequest(reader io.Reader) (*Request, error) {
 		log.Error("Shadowsocks: Failed to read address type: ", err)
 		return nil, transport.CorruptedPacket
 	}
+	lenBuffer := 1
 
 	request := new(Request)
 
@@ -39,43 +41,64 @@ func ReadRequest(reader io.Reader) (*Request, error) {
 	}
 	switch addrType {
 	case AddrTypeIPv4:
-		_, err := io.ReadFull(reader, buffer.Value[:4])
+		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+4])
 		if err != nil {
 			log.Error("Shadowsocks: Failed to read IPv4 address: ", err)
 			return nil, transport.CorruptedPacket
 		}
-		request.Address = v2net.IPAddress(buffer.Value[:4])
+		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+4])
+		lenBuffer += 4
 	case AddrTypeIPv6:
-		_, err := io.ReadFull(reader, buffer.Value[:16])
+		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+16])
 		if err != nil {
 			log.Error("Shadowsocks: Failed to read IPv6 address: ", err)
 			return nil, transport.CorruptedPacket
 		}
-		request.Address = v2net.IPAddress(buffer.Value[:16])
+		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+16])
+		lenBuffer += 16
 	case AddrTypeDomain:
-		_, err := io.ReadFull(reader, buffer.Value[:1])
+		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+1])
 		if err != nil {
 			log.Error("Shadowsocks: Failed to read domain lenth: ", err)
 			return nil, transport.CorruptedPacket
 		}
-		domainLength := int(buffer.Value[0])
-		_, err = io.ReadFull(reader, buffer.Value[:domainLength])
+		domainLength := int(buffer.Value[lenBuffer])
+		lenBuffer++
+		_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+domainLength])
 		if err != nil {
 			log.Error("Shadowsocks: Failed to read domain: ", err)
 			return nil, transport.CorruptedPacket
 		}
-		request.Address = v2net.DomainAddress(string(buffer.Value[:domainLength]))
+		request.Address = v2net.DomainAddress(string(buffer.Value[lenBuffer : lenBuffer+domainLength]))
+		lenBuffer += domainLength
 	default:
 		log.Error("Shadowsocks: Unknown address type: ", addrType)
 		return nil, transport.CorruptedPacket
 	}
 
-	_, err = io.ReadFull(reader, buffer.Value[:2])
+	_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+2])
 	if err != nil {
 		log.Error("Shadowsocks: Failed to read port: ", err)
 		return nil, transport.CorruptedPacket
 	}
 
-	request.Port = v2net.PortFromBytes(buffer.Value[:2])
+	request.Port = v2net.PortFromBytes(buffer.Value[lenBuffer : lenBuffer+2])
+	lenBuffer += 2
+
+	if request.OTA {
+		authBytes := buffer.Value[lenBuffer : lenBuffer+auth.AuthSize()]
+		_, err = io.ReadFull(reader, authBytes)
+		if err != nil {
+			log.Error("Shadowsocks: Failed to read OTA: ", err)
+			return nil, transport.CorruptedPacket
+		}
+
+		actualAuth := auth.Authenticate(nil, buffer.Value[0:lenBuffer])
+		if !serial.BytesLiteral(actualAuth).Equals(serial.BytesLiteral(authBytes)) {
+			log.Error("Shadowsocks: Invalid OTA: ", actualAuth)
+			return nil, transport.CorruptedPacket
+		}
+	}
+
 	return request, nil
 }
diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go
index b230b22d..2b5658d2 100644
--- a/proxy/shadowsocks/protocol_test.go
+++ b/proxy/shadowsocks/protocol_test.go
@@ -17,7 +17,7 @@ func TestNormalRequestParsing(t *testing.T) {
 	buffer := alloc.NewSmallBuffer().Clear()
 	buffer.AppendBytes(1, 127, 0, 0, 1, 0, 80)
 
-	request, err := ReadRequest(buffer)
+	request, err := ReadRequest(buffer, nil)
 	assert.Error(err).IsNil()
 	netassert.Address(request.Address).Equals(v2net.IPAddress([]byte{127, 0, 0, 1}))
 	netassert.Port(request.Port).Equals(v2net.Port(80))
@@ -28,9 +28,12 @@ func TestOTARequest(t *testing.T) {
 	v2testing.Current(t)
 
 	buffer := alloc.NewSmallBuffer().Clear()
-	buffer.AppendBytes(0x13, 13, 119, 119, 119, 46, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 0)
+	buffer.AppendBytes(0x13, 13, 119, 119, 119, 46, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 0, 239, 115, 52, 212, 178, 172, 26, 6, 168, 0)
 
-	request, err := ReadRequest(buffer)
+	auth := NewAuthenticator(HeaderKeyGenerator(
+		[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5},
+		[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5}))
+	request, err := ReadRequest(buffer, auth)
 	assert.Error(err).IsNil()
 	netassert.Address(request.Address).Equals(v2net.DomainAddress("www.v2ray.com"))
 	assert.Bool(request.OTA).IsTrue()
diff --git a/proxy/shadowsocks/shadowsocks.go b/proxy/shadowsocks/shadowsocks.go
index e6f207c5..bd6832e8 100644
--- a/proxy/shadowsocks/shadowsocks.go
+++ b/proxy/shadowsocks/shadowsocks.go
@@ -32,11 +32,16 @@ func (this *Shadowsocks) Port() v2net.Port {
 
 func (this *Shadowsocks) Close() {
 	this.accepting = false
-	this.tcpHub.Close()
-	this.tcpHub = nil
+	if this.tcpHub != nil {
+		this.tcpHub.Close()
+		this.tcpHub = nil
+	}
+
+	if this.udpHub != nil {
+		this.udpHub.Close()
+		this.udpHub = nil
+	}
 
-	this.udpHub.Close()
-	this.udpHub = nil
 }
 
 func (this *Shadowsocks) Listen(port v2net.Port) error {
@@ -80,7 +85,7 @@ func (this *Shadowsocks) handlerUDPPayload(payload *alloc.Buffer, dest v2net.Des
 		return
 	}
 
-	request, err := ReadRequest(reader)
+	request, err := ReadRequest(reader, NewAuthenticator(HeaderKeyGenerator(key, iv)))
 	if err != nil {
 		return
 	}
@@ -95,8 +100,9 @@ func (this *Shadowsocks) handlerUDPPayload(payload *alloc.Buffer, dest v2net.Des
 
 		response := alloc.NewBuffer().Slice(0, this.config.Cipher.IVSize())
 		rand.Read(response.Value)
+		respIv := response.Value
 
-		writer, err := this.config.Cipher.NewEncodingStream(key, response.Value, response)
+		writer, err := this.config.Cipher.NewEncodingStream(key, respIv, response)
 		if err != nil {
 			log.Error("Shadowsocks: Failed to create encoding stream: ", err)
 			return
@@ -118,6 +124,11 @@ func (this *Shadowsocks) handlerUDPPayload(payload *alloc.Buffer, dest v2net.Des
 		writer.Write(respChunk.Value)
 		respChunk.Release()
 
+		if request.OTA {
+			respAuth := NewAuthenticator(HeaderKeyGenerator(key, respIv))
+			respAuth.Authenticate(buffer.Value, buffer.Value[this.config.Cipher.IVSize():])
+		}
+
 		this.udpHub.WriteTo(response.Value, dest)
 		response.Release()
 	}
@@ -144,7 +155,7 @@ func (this *Shadowsocks) handleConnection(conn *hub.TCPConn) {
 		return
 	}
 
-	request, err := ReadRequest(reader)
+	request, err := ReadRequest(reader, NewAuthenticator(HeaderKeyGenerator(key, iv)))
 	if err != nil {
 		return
 	}
@@ -174,7 +185,15 @@ func (this *Shadowsocks) handleConnection(conn *hub.TCPConn) {
 		writeFinish.Unlock()
 	}()
 
-	v2io.RawReaderToChan(ray.InboundInput(), reader)
+	var payloadReader v2io.Reader
+	if request.OTA {
+		payloadAuth := NewAuthenticator(ChunkKeyGenerator(iv))
+		payloadReader = v2io.NewAuthenticationReader(v2io.NewChunkReader(reader), payloadAuth, true)
+	} else {
+		payloadReader = v2io.NewAdaptiveReader(reader)
+	}
+
+	v2io.ReaderToChan(ray.InboundInput(), payloadReader)
 	close(ray.InboundInput())
 
 	writeFinish.Lock()