package proxy
import (
"io"
"net"
"sync/atomic"
)
// Conn represents a single proxied TCP connection.
type Conn struct {
src , dst net . Conn
srcW , dstW countWriter
stopping int32
}
// NewConn returns a conn joining the two given net.Conn
func NewConn ( src , dst net . Conn ) * Conn {
return & Conn {
src : src ,
dst : dst ,
srcW : countWriter { w : src } ,
dstW : countWriter { w : dst } ,
stopping : 0 ,
}
}
// Close closes both connections.
func ( c * Conn ) Close ( ) error {
// Note that net.Conn.Close can be called multiple times and atomic store is
// idempotent so no need to ensure we only do this once.
//
// Also note that we don't wait for CopyBytes to return here since we are
// closing the conns which is the only externally visible sideeffect of that
// goroutine running and there should be no way for it to hang or leak once
// the conns are closed so we can save the extra coordination.
atomic . StoreInt32 ( & c . stopping , 1 )
c . src . Close ( )
c . dst . Close ( )
return nil
}
// CopyBytes will continuously copy bytes in both directions between src and dst
// until either connection is closed.
func ( c * Conn ) CopyBytes ( ) error {
defer c . Close ( )
go func ( ) {
// Need this since Copy is only guaranteed to stop when it's source reader
// (second arg) hits EOF or error but either conn might close first possibly
// causing this goroutine to exit but not the outer one. See
// TestConnSrcClosing which will fail if you comment the defer below.
defer c . Close ( )
io . Copy ( & c . dstW , c . src )
} ( )
_ , err := io . Copy ( & c . srcW , c . dst )
// Note that we don't wait for the other goroutine to finish because it either
// already has due to it's src conn closing, or it will once our defer fires
// and closes the source conn. No need for the extra coordination.
if atomic . LoadInt32 ( & c . stopping ) == 1 {
return nil
}
return err
}
// Stats returns number of bytes transmitted and recieved. Transmit means bytes
// written to dst, receive means bytes written to src.
func ( c * Conn ) Stats ( ) ( txBytes , rxBytes uint64 ) {
return c . srcW . Written ( ) , c . dstW . Written ( )
}
// countWriter is an io.Writer that counts the number of bytes being written
// before passing them through. We use it to gather metrics for bytes
// sent/received. Note that since we are always copying between a net.TCPConn
// and a tls.Conn, none of the optimisations using syscalls like splice and
// ReaderTo/WriterFrom can be used anyway and io.Copy falls back to a generic
// buffered read/write loop.
//
// We use atomic updates to synchronize reads and writes here. It's the cheapest
// uncontended option based on
// https://gist.github.com/banks/e76b40c0cc4b01503f0a0e4e0af231d5. Further
// optimization can be made when if/when identified as a real overhead.
type countWriter struct {
written uint64
w io . Writer
}
// Write implements io.Writer
func ( cw * countWriter ) Write ( p [ ] byte ) ( n int , err error ) {
n , err = cw . w . Write ( p )
atomic . AddUint64 ( & cw . written , uint64 ( n ) )
return
}
// Written returns how many bytes have been written to w.
func ( cw * countWriter ) Written ( ) uint64 {
return atomic . LoadUint64 ( & cw . written )
}