From ebea255c74a906668e9efb59846c1081e8f74777 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sat, 28 Jul 2018 15:03:40 +0200 Subject: [PATCH] eliminate partial writes --- common/buf/io.go | 11 +++++++++++ common/buf/writer.go | 5 +---- proxy/shadowsocks/ota.go | 2 +- proxy/shadowsocks/protocol.go | 4 ++-- proxy/socks/protocol.go | 13 +++++-------- transport/internet/headers/http/http.go | 2 +- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/common/buf/io.go b/common/buf/io.go index 37e739fa..aedfcbf2 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -47,6 +47,17 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier { } } +func WriteAllBytes(writer io.Writer, payload []byte) error { + for len(payload) > 0 { + n, err := writer.Write(payload) + if err != nil { + return err + } + payload = payload[n:] + } + return nil +} + // NewReader creates a new Reader. // The Reader instance doesn't take the ownership of reader. func NewReader(reader io.Reader) Reader { diff --git a/common/buf/writer.go b/common/buf/writer.go index 44918a29..b6b9e0e9 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -179,10 +179,7 @@ func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error { defer mb.Release() for _, b := range mb { - if b.IsEmpty() { - continue - } - if _, err := w.writer.Write(b.Bytes()); err != nil { + if err := WriteAllBytes(w.writer, b.Bytes()); err != nil { return err } } diff --git a/proxy/shadowsocks/ota.go b/proxy/shadowsocks/ota.go index 7128ef3e..e3fd834d 100644 --- a/proxy/shadowsocks/ota.go +++ b/proxy/shadowsocks/ota.go @@ -118,7 +118,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { payloadLen, _ := mb.Read(w.buffer[2+AuthSize:]) serial.Uint16ToBytes(uint16(payloadLen), w.buffer[:0]) w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:]) - if _, err := w.writer.Write(w.buffer[:2+AuthSize+payloadLen]); err != nil { + if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil { return err } if mb.IsEmpty() { diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index b8764d45..836d0894 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -132,7 +132,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri if account.Cipher.IVSize() > 0 { iv = make([]byte, account.Cipher.IVSize()) common.Must2(rand.Read(iv)) - if _, err = writer.Write(iv); err != nil { + if err := buf.WriteAllBytes(writer, iv); err != nil { return nil, newError("failed to write IV") } } @@ -199,7 +199,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr if account.Cipher.IVSize() > 0 { iv = make([]byte, account.Cipher.IVSize()) common.Must2(rand.Read(iv)) - if _, err = writer.Write(iv); err != nil { + if err := buf.WriteAllBytes(writer, iv); err != nil { return nil, newError("failed to write IV.").Base(err) } } diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index e985a5fa..c7d1fe5b 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -234,8 +234,7 @@ func hasAuthMethod(expectedAuth byte, authCandidates []byte) bool { } func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte) error { - _, err := writer.Write([]byte{version, auth}) - return err + return buf.WriteAllBytes(writer, []byte{version, auth}) } func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error { @@ -247,8 +246,7 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po return err } - _, err := writer.Write(buffer.Bytes()) - return err + return buf.WriteAllBytes(writer, buffer.Bytes()) } func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error { @@ -258,8 +256,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po common.Must2(buffer.AppendBytes(0x00, errCode)) common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value()))) common.Must2(buffer.Write(address.IP())) - _, err := writer.Write(buffer.Bytes()) - return err + return buf.WriteAllBytes(writer, buffer.Bytes()) } func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) { @@ -365,7 +362,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i common.Must2(b.Write([]byte(account.Password))) } - if _, err := writer.Write(b.Bytes()); err != nil { + if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { return nil, err } @@ -400,7 +397,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i return nil, err } - if _, err := writer.Write(b.Bytes()); err != nil { + if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { return nil, err } diff --git a/transport/internet/headers/http/http.go b/transport/internet/headers/http/http.go index ac4df9c8..a8a21ae8 100644 --- a/transport/internet/headers/http/http.go +++ b/transport/internet/headers/http/http.go @@ -103,7 +103,7 @@ func (w *HeaderWriter) Write(writer io.Writer) error { if w.header == nil { return nil } - _, err := writer.Write(w.header.Bytes()) + err := buf.WriteAllBytes(writer, w.header.Bytes()) w.header.Release() w.header = nil return err