clean udp writer

pull/298/merge
Darien Raymond 2017-04-21 14:51:09 +02:00
parent eda72624e2
commit 498c7dafdf
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
10 changed files with 48 additions and 55 deletions

View File

@ -128,6 +128,12 @@ func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
} }
} }
func NewSequentialWriter(writer io.Writer) Writer {
return &seqWriter{
writer: writer,
}
}
// ToBytesWriter converts a Writer to io.Writer // ToBytesWriter converts a Writer to io.Writer
func ToBytesWriter(writer Writer) io.Writer { func ToBytesWriter(writer Writer) io.Writer {
return &bytesToBufferWriter{ return &bytesToBufferWriter{

View File

@ -42,6 +42,25 @@ func (w *mergingWriter) Write(mb MultiBuffer) error {
return nil return nil
} }
type seqWriter struct {
writer io.Writer
}
func (w *seqWriter) Write(mb MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if b.IsEmpty() {
continue
}
if _, err := w.writer.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}
type bytesToBufferWriter struct { type bytesToBufferWriter struct {
writer Writer writer Writer
} }

View File

@ -4,7 +4,6 @@ package freedom
import ( import (
"context" "context"
"io"
"runtime" "runtime"
"time" "time"
@ -117,7 +116,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
if destination.Network == net.Network_TCP { if destination.Network == net.Network_TCP {
writer = buf.NewWriter(conn) writer = buf.NewWriter(conn)
} else { } else {
writer = &seqWriter{writer: conn} writer = buf.NewSequentialWriter(conn)
} }
if err := buf.Copy(timer, input, writer); err != nil { if err := buf.Copy(timer, input, writer); err != nil {
return newError("failed to process request").Base(err) return newError("failed to process request").Base(err)
@ -151,19 +150,3 @@ func init() {
return New(ctx, config.(*Config)) return New(ctx, config.(*Config))
})) }))
} }
type seqWriter struct {
writer io.Writer
}
func (w *seqWriter) Write(mb buf.MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if _, err := w.writer.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}

View File

@ -135,10 +135,10 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
writer := &UDPWriter{ writer := buf.NewSequentialWriter(&UDPWriter{
Writer: conn, Writer: conn,
Request: request, Request: request,
} })
requestDone := signal.ExecuteAsync(func() error { requestDone := signal.ExecuteAsync(func() error {
if err := buf.Copy(timer, outboundRay.OutboundInput(), writer); err != nil { if err := buf.Copy(timer, outboundRay.OutboundInput(), writer); err != nil {

View File

@ -238,7 +238,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
return buf.NewWriter(crypto.NewCryptionWriter(stream, writer)), nil return buf.NewWriter(crypto.NewCryptionWriter(stream, writer)), nil
} }
func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf.Buffer, error) { func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
user := request.User user := request.User
rawAccount, err := user.GetTypedAccount() rawAccount, err := user.GetTypedAccount()
if err != nil { if err != nil {
@ -266,7 +266,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf
} }
buffer.AppendSupplier(serial.WriteUint16(uint16(request.Port))) buffer.AppendSupplier(serial.WriteUint16(uint16(request.Port)))
buffer.Append(payload.Bytes()) buffer.Append(payload)
if request.Option.Has(RequestOptionOneTimeAuth) { if request.Option.Has(RequestOptionOneTimeAuth) {
authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
@ -382,23 +382,12 @@ type UDPWriter struct {
Request *protocol.RequestHeader Request *protocol.RequestHeader
} }
func (w *UDPWriter) Write(mb buf.MultiBuffer) error { func (w *UDPWriter) Write(payload []byte) (int, error) {
defer mb.Release() packet, err := EncodeUDPPacket(w.Request, payload)
for _, b := range mb {
if err := w.writeInternal(b); err != nil {
return err
}
}
return nil
}
func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error {
payload, err := EncodeUDPPacket(w.Request, buffer)
if err != nil { if err != nil {
return err return 0, err
} }
_, err = w.Writer.Write(payload.Bytes()) _, err = w.Writer.Write(packet.Bytes())
payload.Release() packet.Release()
return err return len(payload), err
} }

View File

@ -31,7 +31,7 @@ func TestUDPEncoding(t *testing.T) {
data := buf.NewLocal(256) data := buf.NewLocal(256)
data.AppendSupplier(serial.WriteString("test string")) data.AppendSupplier(serial.WriteString("test string"))
encodedData, err := EncodeUDPPacket(request, data) encodedData, err := EncodeUDPPacket(request, data.Bytes())
assert.Error(err).IsNil() assert.Error(err).IsNil()
decodedRequest, decodedData, err := DecodeUDPPacket(request.User, encodedData) decodedRequest, decodedData, err := DecodeUDPPacket(request.User, encodedData)
@ -88,7 +88,7 @@ func TestUDPReaderWriter(t *testing.T) {
}), }),
} }
cache := buf.New() cache := buf.New()
writer := &UDPWriter{ writer := buf.NewSequentialWriter(&UDPWriter{
Writer: cache, Writer: cache,
Request: &protocol.RequestHeader{ Request: &protocol.RequestHeader{
Version: Version, Version: Version,
@ -97,7 +97,7 @@ func TestUDPReaderWriter(t *testing.T) {
User: user, User: user,
Option: RequestOptionOneTimeAuth, Option: RequestOptionOneTimeAuth,
}, },
} })
reader := &UDPReader{ reader := &UDPReader{
Reader: cache, Reader: cache,

View File

@ -113,7 +113,7 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) { udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
defer payload.Release() defer payload.Release()
data, err := EncodeUDPPacket(request, payload) data, err := EncodeUDPPacket(request, payload.Bytes())
if err != nil { if err != nil {
log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning()) log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
return return

View File

@ -103,7 +103,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
} }
defer udpConn.Close() defer udpConn.Close()
requestFunc = func() error { requestFunc = func() error {
return buf.Copy(timer, ray.OutboundInput(), &UDPWriter{request: request, writer: udpConn}) return buf.Copy(timer, ray.OutboundInput(), buf.NewSequentialWriter(NewUDPWriter(request, udpConn)))
} }
responseFunc = func() error { responseFunc = func() error {
defer ray.OutboundOutput().Close() defer ray.OutboundOutput().Close()

View File

@ -369,17 +369,13 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter
} }
} }
func (w *UDPWriter) Write(mb buf.MultiBuffer) error { func (w *UDPWriter) Write(b []byte) (int, error) {
defer mb.Release() eb := EncodeUDPPacket(w.request, b)
for _, b := range mb {
eb := EncodeUDPPacket(w.request, b.Bytes())
defer eb.Release() defer eb.Release()
if _, err := w.writer.Write(eb.Bytes()); err != nil { if _, err := w.writer.Write(eb.Bytes()); err != nil {
return err return 0, err
} }
} return len(b), nil
return nil
} }
func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {

View File

@ -19,7 +19,7 @@ func TestUDPEncoding(t *testing.T) {
Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}), Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
Port: 1024, Port: 1024,
} }
writer := NewUDPWriter(request, b) writer := buf.NewSequentialWriter(NewUDPWriter(request, b))
content := []byte{'a'} content := []byte{'a'}
payload := buf.New() payload := buf.New()