From b64aceabcf6207643e812efe0d43b7dec5b7a44e Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 26 Nov 2017 16:56:01 +0100 Subject: [PATCH] fix aead reader and writer --- proxy/shadowsocks/protocol.go | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 703c33c4..e61cb7c8 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -43,7 +43,8 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea if err != nil { return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError() } - reader = r.(io.Reader) + br := buf.NewBufferedReader(r) + reader = nil authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) request := &protocol.RequestHeader{ @@ -52,7 +53,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea Command: protocol.RequestCommandTCP, } - if err := buffer.Reset(buf.ReadFullFrom(reader, 1)); err != nil { + if err := buffer.Reset(buf.ReadFullFrom(br, 1)); err != nil { return nil, nil, newError("failed to read address type").Base(err) } @@ -73,21 +74,21 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea addrType := (buffer.Byte(0) & 0x0F) switch addrType { case AddrTypeIPv4: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { + if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 4)); err != nil { return nil, nil, newError("failed to read IPv4 address").Base(err) } request.Address = net.IPAddress(buffer.BytesFrom(-4)) case AddrTypeIPv6: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { + if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 16)); err != nil { return nil, nil, newError("failed to read IPv6 address").Base(err) } request.Address = net.IPAddress(buffer.BytesFrom(-16)) case AddrTypeDomain: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { + if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 1)); err != nil { return nil, nil, newError("failed to read domain lenth.").Base(err) } domainLength := int(buffer.BytesFrom(-1)[0]) - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)) + err = buffer.AppendSupplier(buf.ReadFullFrom(br, domainLength)) if err != nil { return nil, nil, newError("failed to read domain").Base(err) } @@ -96,7 +97,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea // Check address validity after OTA verification. } - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)) + err = buffer.AppendSupplier(buf.ReadFullFrom(br, 2)) if err != nil { return nil, nil, newError("failed to read port").Base(err) } @@ -106,7 +107,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea actualAuth := make([]byte, AuthSize) authenticator.Authenticate(buffer.Bytes())(actualAuth) - err := buffer.AppendSupplier(buf.ReadFullFrom(reader, AuthSize)) + err := buffer.AppendSupplier(buf.ReadFullFrom(br, AuthSize)) if err != nil { return nil, nil, newError("Failed to read OTA").Base(err) } @@ -120,11 +121,13 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea return nil, nil, newError("invalid remote address.") } + br.SetBuffered(false) + var chunkReader buf.Reader if request.Option.Has(RequestOptionOneTimeAuth) { - chunkReader = NewChunkReader(reader, NewAuthenticator(ChunkKeyGenerator(iv))) + chunkReader = NewChunkReader(br, NewAuthenticator(ChunkKeyGenerator(iv))) } else { - chunkReader = buf.NewReader(reader) + chunkReader = buf.NewReader(br) } return request, chunkReader, nil @@ -154,8 +157,6 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri return nil, newError("failed to create encoding stream").Base(err).AtError() } - writer = w.(io.Writer) - header := buf.NewLocal(512) switch request.Address.Family() { @@ -185,16 +186,15 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes()))) } - _, err = writer.Write(header.Bytes()) - if err != nil { + if err := w.WriteMultiBuffer(buf.NewMultiBufferValue(header)); err != nil { return nil, newError("failed to write header").Base(err) } var chunkWriter buf.Writer if request.Option.Has(RequestOptionOneTimeAuth) { - chunkWriter = NewChunkWriter(writer, NewAuthenticator(ChunkKeyGenerator(iv))) + chunkWriter = NewChunkWriter(w.(io.Writer), NewAuthenticator(ChunkKeyGenerator(iv))) } else { - chunkWriter = buf.NewWriter(writer) + chunkWriter = w } return chunkWriter, nil