diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index 50274548..9dc4a178 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -7,6 +7,7 @@ import ( "time" "github.com/golang/protobuf/proto" + "v2ray.com/core/common/buf" "v2ray.com/core/common/mux" "v2ray.com/core/common/net" "v2ray.com/core/common/session" @@ -119,6 +120,13 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo tag: tag, } + // Initialize the connection by sending a Keepalive frame + keepalive := buf.New() + mux.FrameMetadata{SessionStatus: mux.SessionStatusKeepAlive}.WriteTo(keepalive) + err = link.Writer.WriteMultiBuffer(buf.MultiBuffer{keepalive}) + if err != nil { + return nil, err + } worker, err := mux.NewServerWorker(context.Background(), w, link) if err != nil { return nil, err diff --git a/proxy/http/client.go b/proxy/http/client.go index 51dca746..318ff332 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -3,10 +3,14 @@ package http import ( + "bufio" "context" "encoding/base64" "io" + "net/http" "strings" + "sync" + "time" "v2ray.com/core" "v2ray.com/core/common" @@ -22,6 +26,7 @@ import ( "v2ray.com/core/transport/internet" ) +// Client is a inbound handler for HTTP protocol type Client struct { serverPicker protocol.ServerPicker policyManager policy.Manager @@ -90,9 +95,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter p = c.policyManager.ForLevel(user.Level) } - if err := setUpHttpTunnel(conn, conn, &destination, user); err != nil { - return err - } + conn = setUpHTTPTunnel(conn, &destination, user) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle) @@ -103,7 +106,15 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } responseFunc := func() error { defer timer.SetTimeout(p.Timeouts.UplinkOnly) - return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + bc := bufio.NewReader(conn) + resp, err := http.ReadResponse(bc, nil) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return newError(resp.Status) + } + return buf.Copy(buf.NewReader(bc), link.Writer, buf.UpdateActivity(timer)) } var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer)) @@ -114,8 +125,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return nil } -// setUpHttpTunnel will create a socket tunnel via HTTP CONNECT method -func setUpHttpTunnel(reader io.Reader, writer io.Writer, destination *net.Destination, user *protocol.MemoryUser) error { +// setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method +func setUpHTTPTunnel(conn internet.Connection, destination *net.Destination, user *protocol.MemoryUser) *tunConn { var headers []string destNetAddr := destination.NetAddr() headers = append(headers, "CONNECT "+destNetAddr+" HTTP/1.1") @@ -129,16 +140,62 @@ func setUpHttpTunnel(reader io.Reader, writer io.Writer, destination *net.Destin b := buf.New() b.WriteString(strings.Join(headers, "\r\n") + "\r\n\r\n") - if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { - return err - } + return newTunConn(conn, b, 5 * time.Millisecond) +} - b.Clear() - if _, err := b.ReadFrom(reader); err != nil { - return err - } +// tunConn is a connection that writes header before content, +// the header will be written during the next Write call or after +// specified delay. +type tunConn struct { + internet.Connection + header *buf.Buffer + once sync.Once + timer *time.Timer +} - return nil +func newTunConn(conn internet.Connection, header *buf.Buffer, delay time.Duration) *tunConn { + tc := &tunConn{ + Connection: conn, + header: header, + } + if delay > 0 { + tc.timer = time.AfterFunc(delay, func() { + tc.Write([]byte{}) + }) + } + return tc +} + +func (c *tunConn) Write(b []byte) (n int, err error) { + // fallback to normal write if header is sent + if c.header == nil { + return c.Connection.Write(b) + } + // Prevent timer and writer race condition + c.once.Do(func() { + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + lenheader := c.header.Len() + // Concate header and b + common.Must2(c.header.Write(b)) + // Write buffer + var nc int64 + nc, err = io.Copy(c.Connection, c.header) + c.header.Release() + c.header = nil + n = int(nc) - int(lenheader) + if n < 0 { n = 0 } + b = b[n:] + }) + // Write Trailing bytes + if len(b) > 0 && err == nil { + var nw int + nw, err = c.Connection.Write(b) + n += nw + } + return n, err } func init() {