eliminate partial writes

pull/1524/head^2
Darien Raymond 6 years ago
parent 7a4b0fff07
commit ebea255c74
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169

@ -47,6 +47,17 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
}
}
func WriteAllBytes(writer io.Writer, payload []byte) error {
for len(payload) > 0 {
n, err := writer.Write(payload)
if err != nil {
return err
}
payload = payload[n:]
}
return nil
}
// NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader {

@ -179,10 +179,7 @@ func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if b.IsEmpty() {
continue
}
if _, err := w.writer.Write(b.Bytes()); err != nil {
if err := WriteAllBytes(w.writer, b.Bytes()); err != nil {
return err
}
}

@ -118,7 +118,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
payloadLen, _ := mb.Read(w.buffer[2+AuthSize:])
serial.Uint16ToBytes(uint16(payloadLen), w.buffer[:0])
w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:])
if _, err := w.writer.Write(w.buffer[:2+AuthSize+payloadLen]); err != nil {
if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil {
return err
}
if mb.IsEmpty() {

@ -132,7 +132,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
if account.Cipher.IVSize() > 0 {
iv = make([]byte, account.Cipher.IVSize())
common.Must2(rand.Read(iv))
if _, err = writer.Write(iv); err != nil {
if err := buf.WriteAllBytes(writer, iv); err != nil {
return nil, newError("failed to write IV")
}
}
@ -199,7 +199,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
if account.Cipher.IVSize() > 0 {
iv = make([]byte, account.Cipher.IVSize())
common.Must2(rand.Read(iv))
if _, err = writer.Write(iv); err != nil {
if err := buf.WriteAllBytes(writer, iv); err != nil {
return nil, newError("failed to write IV.").Base(err)
}
}

@ -234,8 +234,7 @@ func hasAuthMethod(expectedAuth byte, authCandidates []byte) bool {
}
func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte) error {
_, err := writer.Write([]byte{version, auth})
return err
return buf.WriteAllBytes(writer, []byte{version, auth})
}
func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
@ -247,8 +246,7 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po
return err
}
_, err := writer.Write(buffer.Bytes())
return err
return buf.WriteAllBytes(writer, buffer.Bytes())
}
func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
@ -258,8 +256,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
common.Must2(buffer.AppendBytes(0x00, errCode))
common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
common.Must2(buffer.Write(address.IP()))
_, err := writer.Write(buffer.Bytes())
return err
return buf.WriteAllBytes(writer, buffer.Bytes())
}
func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) {
@ -365,7 +362,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
common.Must2(b.Write([]byte(account.Password)))
}
if _, err := writer.Write(b.Bytes()); err != nil {
if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil {
return nil, err
}
@ -400,7 +397,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
return nil, err
}
if _, err := writer.Write(b.Bytes()); err != nil {
if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil {
return nil, err
}

@ -103,7 +103,7 @@ func (w *HeaderWriter) Write(writer io.Writer) error {
if w.header == nil {
return nil
}
_, err := writer.Write(w.header.Bytes())
err := buf.WriteAllBytes(writer, w.header.Bytes())
w.header.Release()
w.header = nil
return err

Loading…
Cancel
Save