diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 43dccf85..b064cc56 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -16,6 +16,7 @@ import ( "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/common/signal/semaphore" "github.com/xtls/xray-core/common/uuid" "github.com/xtls/xray-core/transport/internet" @@ -44,18 +45,6 @@ var ( globalDialerAccess sync.Mutex ) -func destroyHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) { - globalDialerAccess.Lock() - defer globalDialerAccess.Unlock() - - if globalDialerMap == nil { - globalDialerMap = make(map[dialerConf]reusedClient) - } - - delete(globalDialerMap, dialerConf{dest, streamSettings}) - -} - func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) reusedClient { globalDialerAccess.Lock() defer globalDialerAccess.Unlock() @@ -77,7 +66,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in } dialContext := func(ctxInner context.Context) (net.Conn, error) { - conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings) if err != nil { return nil, err } @@ -85,7 +74,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in if gotlsConfig != nil { if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil { conn = tls.UClient(conn, gotlsConfig, fingerprint) - if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil { + if err := conn.(*tls.UConn).HandshakeContext(ctxInner); err != nil { return nil, err } } else { @@ -171,49 +160,73 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me var remoteAddr gonet.Addr var localAddr gonet.Addr + // this is done when the TCP/UDP connection to the server was established, + // and we can unblock the Dial function and print correct net addresses in + // logs + gotConn := done.New() - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - remoteAddr = connInfo.Conn.RemoteAddr() - localAddr = connInfo.Conn.LocalAddr() - }, - } + var downResponse io.ReadCloser + gotDownResponse := done.New() sessionIdUuid := uuid.New() sessionId := sessionIdUuid.String() - req, err := http.NewRequestWithContext( - httptrace.WithClientTrace(ctx, trace), - "GET", - requestURL.String()+"?session="+sessionId, - nil, - ) - if err != nil { - return nil, err - } + go func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + remoteAddr = connInfo.Conn.RemoteAddr() + localAddr = connInfo.Conn.LocalAddr() + gotConn.Close() + }, + } - req.Header = transportConfiguration.GetRequestHeader() - - downResponse, err := httpClient.download.Do(req) - if err != nil { - // workaround for various connection pool related issues, mostly around - // HTTP/1.1. if the http client ever fails to send a request, we simply - // delete it entirely. - // in HTTP/1.1, it was observed that pool connections would immediately - // fail with "context canceled" if the previous http response body was - // not explicitly BOTH drained and closed. at the same time, sometimes - // the draining itself takes forever and causes more problems. - // see also https://github.com/golang/go/issues/60240 - destroyHTTPClient(ctx, dest, streamSettings) - return nil, newError("failed to send download http request, destroying client").Base(err) - } + // in case we hit an error, we want to unblock this part + defer gotConn.Close() - if downResponse.StatusCode != 200 { - downResponse.Body.Close() - return nil, newError("invalid status code on download:", downResponse.Status) - } + req, err := http.NewRequestWithContext( + httptrace.WithClientTrace(context.WithoutCancel(ctx), trace), + "GET", + requestURL.String()+sessionId, + nil, + ) + if err != nil { + newError("failed to construct download http request").Base(err).WriteToLog() + gotDownResponse.Close() + return + } + + req.Header = transportConfiguration.GetRequestHeader() + + response, err := httpClient.download.Do(req) + gotConn.Close() + if err != nil { + newError("failed to send download http request").Base(err).WriteToLog() + gotDownResponse.Close() + return + } + + if response.StatusCode != 200 { + response.Body.Close() + newError("invalid status code on download:", response.Status).WriteToLog() + gotDownResponse.Close() + return + } + + // skip "ok" response + trashHeader := []byte{0, 0} + _, err = io.ReadFull(response.Body, trashHeader) + if err != nil { + response.Body.Close() + newError("failed to read initial response").Base(err).WriteToLog() + gotDownResponse.Close() + return + } - uploadUrl := requestURL.String() + "?session=" + sessionId + "&seq=" + downResponse = response.Body + gotDownResponse.Close() + }() + + uploadUrl := requestURL.String() + sessionId + "/" uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize)) @@ -266,7 +279,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me for i := 0; i < 5; i++ { uploadConn = httpClient.uploadRawPool.Get() if uploadConn == nil { - uploadConn, err = httpClient.dialUploadConn(ctx) + uploadConn, err = httpClient.dialUploadConn(context.WithoutCancel(ctx)) if err != nil { newError("failed to connect upload").Base(err).WriteToLog() uploadPipeReader.Interrupt() @@ -293,21 +306,27 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } }() - // skip "ok" response - trashHeader := []byte{0, 0} - _, err = io.ReadFull(downResponse.Body, trashHeader) - if err != nil { - downResponse.Body.Close() - return nil, newError("failed to read initial response") - } + // we want to block Dial until we know the remote address of the server, + // for logging purposes + <-gotConn.Wait() // necessary in order to send larger chunks in upload bufferedUploadPipeWriter := buf.NewBufferedWriter(uploadPipeWriter) bufferedUploadPipeWriter.SetBuffered(false) + lazyDownload := &LazyReader{ + CreateReader: func() (io.ReadCloser, error) { + <-gotDownResponse.Wait() + if downResponse == nil { + return nil, newError("downResponse failed") + } + return downResponse, nil + }, + } + conn := splitConn{ writer: bufferedUploadPipeWriter, - reader: downResponse.Body, + reader: lazyDownload, remoteAddr: remoteAddr, localAddr: localAddr, } diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 1883bf23..412f686a 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -7,6 +7,7 @@ import ( gonet "net" "net/http" "strconv" + "strings" "sync" "time" @@ -28,20 +29,65 @@ type requestHandler struct { localAddr gonet.TCPAddr } +type httpSession struct { + uploadQueue *UploadQueue + // for as long as the GET request is not opened by the client, this will be + // open ("undone"), and the session may be expired within a certain TTL. + // after the client connects, this becomes "done" and the session lives as + // long as the GET request. + isFullyConnected *done.Instance +} + +func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) { + shouldReap := done.New() + go func() { + time.Sleep(30 * time.Second) + shouldReap.Close() + }() + + select { + case <-isFullyConnected.Wait(): + return + case <-shouldReap.Wait(): + h.sessions.Delete(sessionId) + } +} + +func (h *requestHandler) upsertSession(sessionId string) *httpSession { + currentSessionAny, ok := h.sessions.Load(sessionId) + if ok { + return currentSessionAny.(*httpSession) + } + + s := &httpSession{ + uploadQueue: NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())), + isFullyConnected: done.New(), + } + + h.sessions.Store(sessionId, s) + go h.maybeReapSession(s.isFullyConnected, sessionId) + return s +} + func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if len(h.host) > 0 && request.Host != h.host { newError("failed to validate host, request:", request.Host, ", config:", h.host).WriteToLog() writer.WriteHeader(http.StatusNotFound) return } - if request.URL.Path != h.path { + + if !strings.HasPrefix(request.URL.Path, h.path) { newError("failed to validate path, request:", request.URL.Path, ", config:", h.path).WriteToLog() writer.WriteHeader(http.StatusNotFound) return } - queryString := request.URL.Query() - sessionId := queryString.Get("session") + sessionId := "" + subpath := strings.Split(request.URL.Path[len(h.path):], "/") + if len(subpath) > 0 { + sessionId = subpath[0] + } + if sessionId == "" { newError("no sessionid on request:", request.URL.Path).WriteToLog() writer.WriteHeader(http.StatusBadRequest) @@ -60,15 +106,14 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } } + currentSession := h.upsertSession(sessionId) + if request.Method == "POST" { - uploadQueue, ok := h.sessions.Load(sessionId) - if !ok { - newError("sessionid does not exist").WriteToLog() - writer.WriteHeader(http.StatusBadRequest) - return + seq := "" + if len(subpath) > 1 { + seq = subpath[1] } - seq := queryString.Get("seq") if seq == "" { newError("no seq on request:", request.URL.Path).WriteToLog() writer.WriteHeader(http.StatusBadRequest) @@ -89,7 +134,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req return } - err = uploadQueue.(*UploadQueue).Push(Packet{ + err = currentSession.uploadQueue.Push(Packet{ Payload: payload, Seq: seqInt, }) @@ -107,10 +152,9 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req panic("expected http.ResponseWriter to be an http.Flusher") } - uploadQueue := NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())) - - h.sessions.Store(sessionId, uploadQueue) - // the connection is finished, clean up map + // after GET is done, the connection is finished. disable automatic + // session reaping, and handle it in defer + currentSession.isFullyConnected.Close() defer h.sessions.Delete(sessionId) // magic header instructs nginx + apache to not buffer response body @@ -130,7 +174,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req downloadDone: downloadDone, responseFlusher: responseFlusher, }, - reader: uploadQueue, + reader: currentSession.uploadQueue, remoteAddr: remoteAddr, } diff --git a/transport/internet/splithttp/lazy_reader.go b/transport/internet/splithttp/lazy_reader.go new file mode 100644 index 00000000..5ae4ed55 --- /dev/null +++ b/transport/internet/splithttp/lazy_reader.go @@ -0,0 +1,57 @@ +package splithttp + +import ( + "io" + "sync" +) + +type LazyReader struct { + readerSync sync.Mutex + CreateReader func() (io.ReadCloser, error) + reader io.ReadCloser + readerError error +} + +func (r *LazyReader) getReader() (io.ReadCloser, error) { + r.readerSync.Lock() + defer r.readerSync.Unlock() + if r.reader != nil { + return r.reader, nil + } + + if r.readerError != nil { + return nil, r.readerError + } + + reader, err := r.CreateReader() + if err != nil { + r.readerError = err + return nil, err + } + + r.reader = reader + return reader, nil +} + +func (r *LazyReader) Read(b []byte) (int, error) { + reader, err := r.getReader() + if err != nil { + return 0, err + } + n, err := reader.Read(b) + return n, err +} + +func (r *LazyReader) Close() error { + r.readerSync.Lock() + defer r.readerSync.Unlock() + + var err error + if r.reader != nil { + err = r.reader.Close() + r.reader = nil + r.readerError = newError("closed reader") + } + + return err +}