From 010f34c76c2b9c9a9720090c32a1ac6f9fa87e8e Mon Sep 17 00:00:00 2001 From: Darien Raymond <admin@v2ray.com> Date: Thu, 3 Nov 2016 23:14:27 +0100 Subject: [PATCH] allow single side auth --- .../internet/authenticators/http/config.proto | 3 + .../internet/authenticators/http/http.go | 116 +++++++++++------- 2 files changed, 77 insertions(+), 42 deletions(-) diff --git a/transport/internet/authenticators/http/config.proto b/transport/internet/authenticators/http/config.proto index d4b59321..144b97ce 100644 --- a/transport/internet/authenticators/http/config.proto +++ b/transport/internet/authenticators/http/config.proto @@ -53,6 +53,9 @@ message ResponseConfig { } message Config { + // Settings for authenticating requests. If not set, client side will not send authenication header, and server side will bypass authentication. RequestConfig request = 1; + + // Settings for authenticating responses. If not set, client side will bypass authentication, and server side will not send authentication header. ResponseConfig response = 2; } \ No newline at end of file diff --git a/transport/internet/authenticators/http/http.go b/transport/internet/authenticators/http/http.go index f2346d5c..767a7215 100644 --- a/transport/internet/authenticators/http/http.go +++ b/transport/internet/authenticators/http/http.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "io" "net" "v2ray.com/core/common/alloc" @@ -14,51 +15,73 @@ const ( ENDING = CRLF + CRLF ) +type HeaderReader struct { +} + +func (*HeaderReader) Read(reader io.Reader) (*alloc.Buffer, error) { + buffer := alloc.NewLocalBuffer(2048) + for { + _, err := buffer.FillFrom(reader) + if err != nil { + return nil, err + } + if n := bytes.Index(buffer.Value, []byte(ENDING)); n != -1 { + buffer.SliceFrom(n + len(ENDING)) + break + } + if buffer.Len() >= len(ENDING) { + copy(buffer.Value, buffer.Value[buffer.Len()-len(ENDING):]) + buffer.Slice(0, len(ENDING)) + } + } + return buffer, nil +} + +type HeaderWriter struct { + header *alloc.Buffer +} + +func (this *HeaderWriter) Write(writer io.Writer) error { + if this.header == nil { + return nil + } + _, err := writer.Write(this.header.Value) + this.header.Release() + this.header = nil + return err +} + type HttpConn struct { net.Conn - buffer *alloc.Buffer - readHeader bool - - writeHeaderContent *alloc.Buffer - writeHeader bool + readBuffer *alloc.Buffer + oneTimeReader *HeaderReader + oneTimeWriter *HeaderWriter } -func NewHttpConn(conn net.Conn, writeHeaderContent *alloc.Buffer) *HttpConn { +func NewHttpConn(conn net.Conn, reader *HeaderReader, writer *HeaderWriter) *HttpConn { return &HttpConn{ - Conn: conn, - readHeader: true, - writeHeader: true, - writeHeaderContent: writeHeaderContent, + Conn: conn, + oneTimeReader: reader, + oneTimeWriter: writer, } } func (this *HttpConn) Read(b []byte) (int, error) { - if this.readHeader { - buffer := alloc.NewLocalBuffer(2048) - for { - _, err := buffer.FillFrom(this.Conn) - if err != nil { - return 0, err - } - if n := bytes.Index(buffer.Value, []byte(ENDING)); n != -1 { - buffer.SliceFrom(n + len(ENDING)) - break - } - if buffer.Len() >= len(ENDING) { - copy(buffer.Value, buffer.Value[buffer.Len()-len(ENDING):]) - buffer.Slice(0, len(ENDING)) - } + if this.oneTimeReader != nil { + buffer, err := this.oneTimeReader.Read(this.Conn) + if err != nil { + return 0, err } - this.buffer = buffer - this.readHeader = false + this.readBuffer = buffer + this.oneTimeReader = nil } - if this.buffer.Len() > 0 { - nBytes, err := this.buffer.Read(b) - if nBytes == this.buffer.Len() { - this.buffer.Release() - this.buffer = nil + if this.readBuffer.Len() > 0 { + nBytes, err := this.readBuffer.Read(b) + if nBytes == this.readBuffer.Len() { + this.readBuffer.Release() + this.readBuffer = nil } return nBytes, err } @@ -67,13 +90,12 @@ func (this *HttpConn) Read(b []byte) (int, error) { } func (this *HttpConn) Write(b []byte) (int, error) { - if this.writeHeader { - _, err := this.Conn.Write(this.writeHeaderContent.Value) - this.writeHeaderContent.Release() + if this.oneTimeWriter != nil { + err := this.oneTimeWriter.Write(this.Conn) + this.oneTimeWriter = nil if err != nil { return 0, err } - this.writeHeader = false } return this.Conn.Write(b) @@ -83,7 +105,7 @@ type HttpAuthenticator struct { config *Config } -func (this HttpAuthenticator) GetClientWriteHeader() *alloc.Buffer { +func (this HttpAuthenticator) GetClientWriter() *HeaderWriter { header := alloc.NewLocalBuffer(2048) config := this.config.Request header.AppendString(config.Method.GetValue()).AppendString(" ").AppendString(config.PickUri()).AppendString(" ").AppendString(config.GetFullVersion()).AppendString(CRLF) @@ -93,10 +115,12 @@ func (this HttpAuthenticator) GetClientWriteHeader() *alloc.Buffer { header.AppendString(h).AppendString(CRLF) } header.AppendString(CRLF) - return header + return &HeaderWriter{ + header: header, + } } -func (this HttpAuthenticator) GetServerWriteHeader() *alloc.Buffer { +func (this HttpAuthenticator) GetServerWriter() *HeaderWriter { header := alloc.NewLocalBuffer(2048) config := this.config.Response header.AppendString(config.GetFullVersion()).AppendString(" ").AppendString(config.Status.GetCode()).AppendString(" ").AppendString(config.Status.GetReason()).AppendString(CRLF) @@ -106,15 +130,23 @@ func (this HttpAuthenticator) GetServerWriteHeader() *alloc.Buffer { header.AppendString(h).AppendString(CRLF) } header.AppendString(CRLF) - return header + return &HeaderWriter{ + header: header, + } } func (this HttpAuthenticator) Client(conn net.Conn) net.Conn { - return NewHttpConn(conn, this.GetClientWriteHeader()) + if this.config.Request == nil && this.config.Response == nil { + return conn + } + return NewHttpConn(conn, new(HeaderReader), this.GetClientWriter()) } func (this HttpAuthenticator) Server(conn net.Conn) net.Conn { - return NewHttpConn(conn, this.GetServerWriteHeader()) + if this.config.Request == nil && this.config.Response == nil { + return conn + } + return NewHttpConn(conn, new(HeaderReader), this.GetServerWriter()) } type HttpAuthenticatorFactory struct{}