mirror of https://github.com/v2ray/v2ray-core
fix aead reader and writer
parent
40222de0f7
commit
b64aceabcf
|
@ -43,7 +43,8 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
|
return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
|
||||||
}
|
}
|
||||||
reader = r.(io.Reader)
|
br := buf.NewBufferedReader(r)
|
||||||
|
reader = nil
|
||||||
|
|
||||||
authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
|
authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
|
||||||
request := &protocol.RequestHeader{
|
request := &protocol.RequestHeader{
|
||||||
|
@ -52,7 +53,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
||||||
Command: protocol.RequestCommandTCP,
|
Command: protocol.RequestCommandTCP,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := buffer.Reset(buf.ReadFullFrom(reader, 1)); err != nil {
|
if err := buffer.Reset(buf.ReadFullFrom(br, 1)); err != nil {
|
||||||
return nil, nil, newError("failed to read address type").Base(err)
|
return nil, nil, newError("failed to read address type").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,21 +74,21 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
||||||
addrType := (buffer.Byte(0) & 0x0F)
|
addrType := (buffer.Byte(0) & 0x0F)
|
||||||
switch addrType {
|
switch addrType {
|
||||||
case AddrTypeIPv4:
|
case AddrTypeIPv4:
|
||||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 4)); err != nil {
|
||||||
return nil, nil, newError("failed to read IPv4 address").Base(err)
|
return nil, nil, newError("failed to read IPv4 address").Base(err)
|
||||||
}
|
}
|
||||||
request.Address = net.IPAddress(buffer.BytesFrom(-4))
|
request.Address = net.IPAddress(buffer.BytesFrom(-4))
|
||||||
case AddrTypeIPv6:
|
case AddrTypeIPv6:
|
||||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 16)); err != nil {
|
||||||
return nil, nil, newError("failed to read IPv6 address").Base(err)
|
return nil, nil, newError("failed to read IPv6 address").Base(err)
|
||||||
}
|
}
|
||||||
request.Address = net.IPAddress(buffer.BytesFrom(-16))
|
request.Address = net.IPAddress(buffer.BytesFrom(-16))
|
||||||
case AddrTypeDomain:
|
case AddrTypeDomain:
|
||||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
|
if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 1)); err != nil {
|
||||||
return nil, nil, newError("failed to read domain lenth.").Base(err)
|
return nil, nil, newError("failed to read domain lenth.").Base(err)
|
||||||
}
|
}
|
||||||
domainLength := int(buffer.BytesFrom(-1)[0])
|
domainLength := int(buffer.BytesFrom(-1)[0])
|
||||||
err = buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength))
|
err = buffer.AppendSupplier(buf.ReadFullFrom(br, domainLength))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, newError("failed to read domain").Base(err)
|
return nil, nil, newError("failed to read domain").Base(err)
|
||||||
}
|
}
|
||||||
|
@ -96,7 +97,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
||||||
// Check address validity after OTA verification.
|
// Check address validity after OTA verification.
|
||||||
}
|
}
|
||||||
|
|
||||||
err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 2))
|
err = buffer.AppendSupplier(buf.ReadFullFrom(br, 2))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, newError("failed to read port").Base(err)
|
return nil, nil, newError("failed to read port").Base(err)
|
||||||
}
|
}
|
||||||
|
@ -106,7 +107,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
||||||
actualAuth := make([]byte, AuthSize)
|
actualAuth := make([]byte, AuthSize)
|
||||||
authenticator.Authenticate(buffer.Bytes())(actualAuth)
|
authenticator.Authenticate(buffer.Bytes())(actualAuth)
|
||||||
|
|
||||||
err := buffer.AppendSupplier(buf.ReadFullFrom(reader, AuthSize))
|
err := buffer.AppendSupplier(buf.ReadFullFrom(br, AuthSize))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, newError("Failed to read OTA").Base(err)
|
return nil, nil, newError("Failed to read OTA").Base(err)
|
||||||
}
|
}
|
||||||
|
@ -120,11 +121,13 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
||||||
return nil, nil, newError("invalid remote address.")
|
return nil, nil, newError("invalid remote address.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
br.SetBuffered(false)
|
||||||
|
|
||||||
var chunkReader buf.Reader
|
var chunkReader buf.Reader
|
||||||
if request.Option.Has(RequestOptionOneTimeAuth) {
|
if request.Option.Has(RequestOptionOneTimeAuth) {
|
||||||
chunkReader = NewChunkReader(reader, NewAuthenticator(ChunkKeyGenerator(iv)))
|
chunkReader = NewChunkReader(br, NewAuthenticator(ChunkKeyGenerator(iv)))
|
||||||
} else {
|
} else {
|
||||||
chunkReader = buf.NewReader(reader)
|
chunkReader = buf.NewReader(br)
|
||||||
}
|
}
|
||||||
|
|
||||||
return request, chunkReader, nil
|
return request, chunkReader, nil
|
||||||
|
@ -154,8 +157,6 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
|
||||||
return nil, newError("failed to create encoding stream").Base(err).AtError()
|
return nil, newError("failed to create encoding stream").Base(err).AtError()
|
||||||
}
|
}
|
||||||
|
|
||||||
writer = w.(io.Writer)
|
|
||||||
|
|
||||||
header := buf.NewLocal(512)
|
header := buf.NewLocal(512)
|
||||||
|
|
||||||
switch request.Address.Family() {
|
switch request.Address.Family() {
|
||||||
|
@ -185,16 +186,15 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
|
||||||
common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes())))
|
common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes())))
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = writer.Write(header.Bytes())
|
if err := w.WriteMultiBuffer(buf.NewMultiBufferValue(header)); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, newError("failed to write header").Base(err)
|
return nil, newError("failed to write header").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var chunkWriter buf.Writer
|
var chunkWriter buf.Writer
|
||||||
if request.Option.Has(RequestOptionOneTimeAuth) {
|
if request.Option.Has(RequestOptionOneTimeAuth) {
|
||||||
chunkWriter = NewChunkWriter(writer, NewAuthenticator(ChunkKeyGenerator(iv)))
|
chunkWriter = NewChunkWriter(w.(io.Writer), NewAuthenticator(ChunkKeyGenerator(iv)))
|
||||||
} else {
|
} else {
|
||||||
chunkWriter = buf.NewWriter(writer)
|
chunkWriter = w
|
||||||
}
|
}
|
||||||
|
|
||||||
return chunkWriter, nil
|
return chunkWriter, nil
|
||||||
|
|
Loading…
Reference in New Issue