From 010f34c76c2b9c9a9720090c32a1ac6f9fa87e8e Mon Sep 17 00:00:00 2001
From: Darien Raymond <admin@v2ray.com>
Date: Thu, 3 Nov 2016 23:14:27 +0100
Subject: [PATCH] allow single side auth

---
 .../internet/authenticators/http/config.proto |   3 +
 .../internet/authenticators/http/http.go      | 116 +++++++++++-------
 2 files changed, 77 insertions(+), 42 deletions(-)

diff --git a/transport/internet/authenticators/http/config.proto b/transport/internet/authenticators/http/config.proto
index d4b59321..144b97ce 100644
--- a/transport/internet/authenticators/http/config.proto
+++ b/transport/internet/authenticators/http/config.proto
@@ -53,6 +53,9 @@ message ResponseConfig {
 }
 
 message Config {
+  // Settings for authenticating requests. If not set, client side will not send authenication header, and server side will bypass authentication.
   RequestConfig request = 1;
+
+  // Settings for authenticating responses. If not set, client side will bypass authentication, and server side will not send authentication header.
   ResponseConfig response = 2;
 }
\ No newline at end of file
diff --git a/transport/internet/authenticators/http/http.go b/transport/internet/authenticators/http/http.go
index f2346d5c..767a7215 100644
--- a/transport/internet/authenticators/http/http.go
+++ b/transport/internet/authenticators/http/http.go
@@ -2,6 +2,7 @@ package http
 
 import (
 	"bytes"
+	"io"
 	"net"
 
 	"v2ray.com/core/common/alloc"
@@ -14,51 +15,73 @@ const (
 	ENDING = CRLF + CRLF
 )
 
+type HeaderReader struct {
+}
+
+func (*HeaderReader) Read(reader io.Reader) (*alloc.Buffer, error) {
+	buffer := alloc.NewLocalBuffer(2048)
+	for {
+		_, err := buffer.FillFrom(reader)
+		if err != nil {
+			return nil, err
+		}
+		if n := bytes.Index(buffer.Value, []byte(ENDING)); n != -1 {
+			buffer.SliceFrom(n + len(ENDING))
+			break
+		}
+		if buffer.Len() >= len(ENDING) {
+			copy(buffer.Value, buffer.Value[buffer.Len()-len(ENDING):])
+			buffer.Slice(0, len(ENDING))
+		}
+	}
+	return buffer, nil
+}
+
+type HeaderWriter struct {
+	header *alloc.Buffer
+}
+
+func (this *HeaderWriter) Write(writer io.Writer) error {
+	if this.header == nil {
+		return nil
+	}
+	_, err := writer.Write(this.header.Value)
+	this.header.Release()
+	this.header = nil
+	return err
+}
+
 type HttpConn struct {
 	net.Conn
 
-	buffer     *alloc.Buffer
-	readHeader bool
-
-	writeHeaderContent *alloc.Buffer
-	writeHeader        bool
+	readBuffer    *alloc.Buffer
+	oneTimeReader *HeaderReader
+	oneTimeWriter *HeaderWriter
 }
 
-func NewHttpConn(conn net.Conn, writeHeaderContent *alloc.Buffer) *HttpConn {
+func NewHttpConn(conn net.Conn, reader *HeaderReader, writer *HeaderWriter) *HttpConn {
 	return &HttpConn{
-		Conn:               conn,
-		readHeader:         true,
-		writeHeader:        true,
-		writeHeaderContent: writeHeaderContent,
+		Conn:          conn,
+		oneTimeReader: reader,
+		oneTimeWriter: writer,
 	}
 }
 
 func (this *HttpConn) Read(b []byte) (int, error) {
-	if this.readHeader {
-		buffer := alloc.NewLocalBuffer(2048)
-		for {
-			_, err := buffer.FillFrom(this.Conn)
-			if err != nil {
-				return 0, err
-			}
-			if n := bytes.Index(buffer.Value, []byte(ENDING)); n != -1 {
-				buffer.SliceFrom(n + len(ENDING))
-				break
-			}
-			if buffer.Len() >= len(ENDING) {
-				copy(buffer.Value, buffer.Value[buffer.Len()-len(ENDING):])
-				buffer.Slice(0, len(ENDING))
-			}
+	if this.oneTimeReader != nil {
+		buffer, err := this.oneTimeReader.Read(this.Conn)
+		if err != nil {
+			return 0, err
 		}
-		this.buffer = buffer
-		this.readHeader = false
+		this.readBuffer = buffer
+		this.oneTimeReader = nil
 	}
 
-	if this.buffer.Len() > 0 {
-		nBytes, err := this.buffer.Read(b)
-		if nBytes == this.buffer.Len() {
-			this.buffer.Release()
-			this.buffer = nil
+	if this.readBuffer.Len() > 0 {
+		nBytes, err := this.readBuffer.Read(b)
+		if nBytes == this.readBuffer.Len() {
+			this.readBuffer.Release()
+			this.readBuffer = nil
 		}
 		return nBytes, err
 	}
@@ -67,13 +90,12 @@ func (this *HttpConn) Read(b []byte) (int, error) {
 }
 
 func (this *HttpConn) Write(b []byte) (int, error) {
-	if this.writeHeader {
-		_, err := this.Conn.Write(this.writeHeaderContent.Value)
-		this.writeHeaderContent.Release()
+	if this.oneTimeWriter != nil {
+		err := this.oneTimeWriter.Write(this.Conn)
+		this.oneTimeWriter = nil
 		if err != nil {
 			return 0, err
 		}
-		this.writeHeader = false
 	}
 
 	return this.Conn.Write(b)
@@ -83,7 +105,7 @@ type HttpAuthenticator struct {
 	config *Config
 }
 
-func (this HttpAuthenticator) GetClientWriteHeader() *alloc.Buffer {
+func (this HttpAuthenticator) GetClientWriter() *HeaderWriter {
 	header := alloc.NewLocalBuffer(2048)
 	config := this.config.Request
 	header.AppendString(config.Method.GetValue()).AppendString(" ").AppendString(config.PickUri()).AppendString(" ").AppendString(config.GetFullVersion()).AppendString(CRLF)
@@ -93,10 +115,12 @@ func (this HttpAuthenticator) GetClientWriteHeader() *alloc.Buffer {
 		header.AppendString(h).AppendString(CRLF)
 	}
 	header.AppendString(CRLF)
-	return header
+	return &HeaderWriter{
+		header: header,
+	}
 }
 
-func (this HttpAuthenticator) GetServerWriteHeader() *alloc.Buffer {
+func (this HttpAuthenticator) GetServerWriter() *HeaderWriter {
 	header := alloc.NewLocalBuffer(2048)
 	config := this.config.Response
 	header.AppendString(config.GetFullVersion()).AppendString(" ").AppendString(config.Status.GetCode()).AppendString(" ").AppendString(config.Status.GetReason()).AppendString(CRLF)
@@ -106,15 +130,23 @@ func (this HttpAuthenticator) GetServerWriteHeader() *alloc.Buffer {
 		header.AppendString(h).AppendString(CRLF)
 	}
 	header.AppendString(CRLF)
-	return header
+	return &HeaderWriter{
+		header: header,
+	}
 }
 
 func (this HttpAuthenticator) Client(conn net.Conn) net.Conn {
-	return NewHttpConn(conn, this.GetClientWriteHeader())
+	if this.config.Request == nil && this.config.Response == nil {
+		return conn
+	}
+	return NewHttpConn(conn, new(HeaderReader), this.GetClientWriter())
 }
 
 func (this HttpAuthenticator) Server(conn net.Conn) net.Conn {
-	return NewHttpConn(conn, this.GetServerWriteHeader())
+	if this.config.Request == nil && this.config.Response == nil {
+		return conn
+	}
+	return NewHttpConn(conn, new(HeaderReader), this.GetServerWriter())
 }
 
 type HttpAuthenticatorFactory struct{}