eliminate partial writes

pull/1524/head^2
Darien Raymond 6 years ago
parent 7a4b0fff07
commit ebea255c74
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169

@ -47,6 +47,17 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
} }
} }
func WriteAllBytes(writer io.Writer, payload []byte) error {
for len(payload) > 0 {
n, err := writer.Write(payload)
if err != nil {
return err
}
payload = payload[n:]
}
return nil
}
// NewReader creates a new Reader. // NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader. // The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader { func NewReader(reader io.Reader) Reader {

@ -179,10 +179,7 @@ func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release() defer mb.Release()
for _, b := range mb { for _, b := range mb {
if b.IsEmpty() { if err := WriteAllBytes(w.writer, b.Bytes()); err != nil {
continue
}
if _, err := w.writer.Write(b.Bytes()); err != nil {
return err return err
} }
} }

@ -118,7 +118,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
payloadLen, _ := mb.Read(w.buffer[2+AuthSize:]) payloadLen, _ := mb.Read(w.buffer[2+AuthSize:])
serial.Uint16ToBytes(uint16(payloadLen), w.buffer[:0]) serial.Uint16ToBytes(uint16(payloadLen), w.buffer[:0])
w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:]) w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:])
if _, err := w.writer.Write(w.buffer[:2+AuthSize+payloadLen]); err != nil { if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil {
return err return err
} }
if mb.IsEmpty() { if mb.IsEmpty() {

@ -132,7 +132,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
if account.Cipher.IVSize() > 0 { if account.Cipher.IVSize() > 0 {
iv = make([]byte, account.Cipher.IVSize()) iv = make([]byte, account.Cipher.IVSize())
common.Must2(rand.Read(iv)) common.Must2(rand.Read(iv))
if _, err = writer.Write(iv); err != nil { if err := buf.WriteAllBytes(writer, iv); err != nil {
return nil, newError("failed to write IV") return nil, newError("failed to write IV")
} }
} }
@ -199,7 +199,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
if account.Cipher.IVSize() > 0 { if account.Cipher.IVSize() > 0 {
iv = make([]byte, account.Cipher.IVSize()) iv = make([]byte, account.Cipher.IVSize())
common.Must2(rand.Read(iv)) common.Must2(rand.Read(iv))
if _, err = writer.Write(iv); err != nil { if err := buf.WriteAllBytes(writer, iv); err != nil {
return nil, newError("failed to write IV.").Base(err) return nil, newError("failed to write IV.").Base(err)
} }
} }

@ -234,8 +234,7 @@ func hasAuthMethod(expectedAuth byte, authCandidates []byte) bool {
} }
func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte) error { func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte) error {
_, err := writer.Write([]byte{version, auth}) return buf.WriteAllBytes(writer, []byte{version, auth})
return err
} }
func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error { func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
@ -247,8 +246,7 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po
return err return err
} }
_, err := writer.Write(buffer.Bytes()) return buf.WriteAllBytes(writer, buffer.Bytes())
return err
} }
func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error { func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
@ -258,8 +256,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
common.Must2(buffer.AppendBytes(0x00, errCode)) common.Must2(buffer.AppendBytes(0x00, errCode))
common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value()))) common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
common.Must2(buffer.Write(address.IP())) common.Must2(buffer.Write(address.IP()))
_, err := writer.Write(buffer.Bytes()) return buf.WriteAllBytes(writer, buffer.Bytes())
return err
} }
func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) { func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) {
@ -365,7 +362,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
common.Must2(b.Write([]byte(account.Password))) common.Must2(b.Write([]byte(account.Password)))
} }
if _, err := writer.Write(b.Bytes()); err != nil { if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil {
return nil, err return nil, err
} }
@ -400,7 +397,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
return nil, err return nil, err
} }
if _, err := writer.Write(b.Bytes()); err != nil { if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil {
return nil, err return nil, err
} }

@ -103,7 +103,7 @@ func (w *HeaderWriter) Write(writer io.Writer) error {
if w.header == nil { if w.header == nil {
return nil return nil
} }
_, err := writer.Write(w.header.Bytes()) err := buf.WriteAllBytes(writer, w.header.Bytes())
w.header.Release() w.header.Release()
w.header = nil w.header = nil
return err return err

Loading…
Cancel
Save