diff --git a/client/client.go b/client/client.go index 64910f6..377bb62 100755 --- a/client/client.go +++ b/client/client.go @@ -2,6 +2,11 @@ package client import ( "bufio" + "net" + "net/http" + "strconv" + "time" + "github.com/cnlh/nps/lib/common" "github.com/cnlh/nps/lib/config" "github.com/cnlh/nps/lib/conn" @@ -9,10 +14,6 @@ import ( "github.com/cnlh/nps/lib/mux" "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs" "github.com/cnlh/nps/vender/github.com/xtaci/kcp" - "net" - "net/http" - "strconv" - "time" ) type TRPClient struct { @@ -159,7 +160,7 @@ func (s *TRPClient) handleChan(src net.Conn) { lk.Host = common.FormatAddress(lk.Host) //if Conn type is http, read the request and log if lk.ConnType == "http" { - if targetConn, err := net.Dial(common.CONN_TCP, lk.Host); err != nil { + if targetConn, err := net.DialTimeout(common.CONN_TCP, lk.Host, lk.Option.Timeout); err != nil { logs.Warn("connect to %s error %s", lk.Host, err.Error()) src.Close() } else { @@ -183,7 +184,7 @@ func (s *TRPClient) handleChan(src net.Conn) { return } //connect to target if conn type is tcp or udp - if targetConn, err := net.Dial(lk.ConnType, lk.Host); err != nil { + if targetConn, err := net.DialTimeout(lk.ConnType, lk.Host, lk.Option.Timeout); err != nil { logs.Warn("connect to %s error %s", lk.Host, err.Error()) src.Close() } else { diff --git a/lib/conn/link.go b/lib/conn/link.go index 974827d..653894e 100644 --- a/lib/conn/link.go +++ b/lib/conn/link.go @@ -1,5 +1,7 @@ package conn +import "time" + type Secret struct { Password string Conn *Conn @@ -19,9 +21,20 @@ type Link struct { Compress bool LocalProxy bool RemoteAddr string + Option Options } -func NewLink(connType string, host string, crypt bool, compress bool, remoteAddr string, localProxy bool) *Link { +type Option func(*Options) + +type Options struct { + Timeout time.Duration +} + +var defaultTimeOut = time.Second * 5 + +func NewLink(connType string, host string, crypt bool, compress bool, remoteAddr string, localProxy bool, opts ...Option) *Link { + options := newOptions(opts...) + return &Link{ RemoteAddr: remoteAddr, ConnType: connType, @@ -29,5 +42,22 @@ func NewLink(connType string, host string, crypt bool, compress bool, remoteAddr Crypt: crypt, Compress: compress, LocalProxy: localProxy, + Option: options, + } +} + +func newOptions(opts ...Option) Options { + opt := Options{ + Timeout: defaultTimeOut, + } + for _, o := range opts { + o(&opt) + } + return opt +} + +func LinkTimeout(t time.Duration) Option { + return func(opt *Options) { + opt.Timeout = t } } diff --git a/lib/mux/conn.go b/lib/mux/conn.go index 9e66577..2c9abac 100644 --- a/lib/mux/conn.go +++ b/lib/mux/conn.go @@ -93,7 +93,8 @@ func (s *conn) Write(buf []byte) (int, error) { if s.isClose { return 0, errors.New("the conn has closed") } - ch := make(chan struct{}) + ch := make(chan error) + var err error go s.write(buf, ch) if t := s.writeTimeOut.Sub(time.Now()); t > 0 { timer := time.NewTimer(t) @@ -101,17 +102,20 @@ func (s *conn) Write(buf []byte) (int, error) { select { case <-timer.C: return 0, errors.New("write timeout") - case <-ch: + case err = <-ch: } } else { - <-ch + err = <-ch } if s.isClose { return 0, io.EOF } + if err != nil { + return 0, err + } return len(buf), nil } -func (s *conn) write(buf []byte, ch chan struct{}) { +func (s *conn) write(buf []byte, ch chan error) { start := 0 l := len(buf) for { @@ -120,14 +124,18 @@ func (s *conn) write(buf []byte, ch chan struct{}) { } s.hasWrite++ if l-start > pool.PoolSizeCopy { - s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:start+pool.PoolSizeCopy]) + if err := s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:start+pool.PoolSizeCopy]); err != nil { + ch <- err + } start += pool.PoolSizeCopy } else { - s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:l]) + if err := s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:l]); err != nil { + ch <- err + } break } } - ch <- struct{}{} + ch <- nil } func (s *conn) Close() error {