Browse Source

Fix SplitHTTP Unix domain socket (#3577)

Co-authored-by: mmmray <142015632+mmmray@users.noreply.github.com>
pull/3593/head
hellokindle 4 months ago committed by GitHub
parent
commit
edae38c620
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      transport/internet/splithttp/client.go
  2. 15
      transport/internet/splithttp/hub.go
  3. 62
      transport/internet/splithttp/splithttp_test.go

2
transport/internet/splithttp/client.go

@ -117,10 +117,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
req, err := http.NewRequest("POST", url, payload)
req.ContentLength = contentLength
if err != nil {
return err
}
req.ContentLength = contentLength
req.Header = c.transportConfig.GetRequestHeader()
if c.isH2 || c.isH3 {

15
transport/internet/splithttp/hub.go

@ -314,14 +314,6 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
}
errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
// h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
l.server = http.Server{
Handler: h2cHandler,
ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192,
}
}
// tcp/unix (h1/h2)
@ -332,7 +324,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
}
}
// h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
l.listener = listener
l.server = http.Server{
Handler: h2cHandler,
ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192,
}
go func() {
if err := l.server.Serve(l.listener); err != nil {

62
transport/internet/splithttp/splithttp_test.go

@ -298,3 +298,65 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
t.Error("end: ", end, " start: ", start)
}
}
func Test_listenSHAndDial_Unix(t *testing.T) {
tempDir := t.TempDir()
tempSocket := tempDir + "/server.sock"
listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
ProtocolName: "splithttp",
ProtocolSettings: &Config{
Path: "/sh",
},
}, func(conn stat.Connection) {
go func(c stat.Connection) {
defer c.Close()
var b [1024]byte
c.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err := c.Read(b[:])
if err != nil {
return
}
common.Must2(c.Write([]byte("Response")))
}(conn)
})
common.Must(err)
ctx := context.Background()
streamSettings := &internet.MemoryStreamConfig{
ProtocolName: "splithttp",
ProtocolSettings: &Config{
Host: "example.com",
Path: "sh",
},
}
conn, err := Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings)
common.Must(err)
_, err = conn.Write([]byte("Test connection 1"))
common.Must(err)
var b [1024]byte
fmt.Println("test2")
n, _ := conn.Read(b[:])
fmt.Println("string is", n)
if string(b[:n]) != "Response" {
t.Error("response: ", string(b[:n]))
}
common.Must(conn.Close())
conn, err = Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings)
common.Must(err)
_, err = conn.Write([]byte("Test connection 2"))
common.Must(err)
n, _ = conn.Read(b[:])
common.Must(err)
if string(b[:n]) != "Response" {
t.Error("response: ", string(b[:n]))
}
common.Must(conn.Close())
common.Must(listen.Close())
}

Loading…
Cancel
Save