mirror of https://github.com/XTLS/Xray-core
146 lines
2.8 KiB
Go
146 lines
2.8 KiB
Go
package encoding
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
|
|
"github.com/xtls/xray-core/common/buf"
|
|
"github.com/xtls/xray-core/common/errors"
|
|
xnet "github.com/xtls/xray-core/common/net"
|
|
"github.com/xtls/xray-core/common/net/cnc"
|
|
"github.com/xtls/xray-core/common/signal/done"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/peer"
|
|
)
|
|
|
|
type MultiHunkConn interface {
|
|
Context() context.Context
|
|
Send(*MultiHunk) error
|
|
Recv() (*MultiHunk, error)
|
|
SendMsg(m interface{}) error
|
|
RecvMsg(m interface{}) error
|
|
}
|
|
|
|
type MultiHunkReaderWriter struct {
|
|
hc MultiHunkConn
|
|
cancel context.CancelFunc
|
|
done *done.Instance
|
|
|
|
buf [][]byte
|
|
}
|
|
|
|
func NewMultiHunkReadWriter(hc MultiHunkConn, cancel context.CancelFunc) *MultiHunkReaderWriter {
|
|
return &MultiHunkReaderWriter{hc, cancel, done.New(), nil}
|
|
}
|
|
|
|
func NewMultiHunkConn(hc MultiHunkConn, cancel context.CancelFunc) net.Conn {
|
|
var rAddr net.Addr
|
|
pr, ok := peer.FromContext(hc.Context())
|
|
if ok {
|
|
rAddr = pr.Addr
|
|
} else {
|
|
rAddr = &net.TCPAddr{
|
|
IP: []byte{0, 0, 0, 0},
|
|
Port: 0,
|
|
}
|
|
}
|
|
|
|
md, ok := metadata.FromIncomingContext(hc.Context())
|
|
if ok {
|
|
header := md.Get("x-real-ip")
|
|
if len(header) > 0 {
|
|
realip := xnet.ParseAddress(header[0])
|
|
if realip.Family().IsIP() {
|
|
rAddr = &net.TCPAddr{
|
|
IP: realip.IP(),
|
|
Port: 0,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
wrc := NewMultiHunkReadWriter(hc, cancel)
|
|
return cnc.NewConnection(
|
|
cnc.ConnectionInputMulti(wrc),
|
|
cnc.ConnectionOutputMulti(wrc),
|
|
cnc.ConnectionOnClose(wrc),
|
|
cnc.ConnectionRemoteAddr(rAddr),
|
|
)
|
|
}
|
|
|
|
func (h *MultiHunkReaderWriter) forceFetch() error {
|
|
hunk, err := h.hc.Recv()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return err
|
|
}
|
|
|
|
return errors.New("failed to fetch hunk from gRPC tunnel").Base(err)
|
|
}
|
|
|
|
h.buf = hunk.Data
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *MultiHunkReaderWriter) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
|
if h.done.Done() {
|
|
return nil, io.EOF
|
|
}
|
|
|
|
if err := h.forceFetch(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mb := make(buf.MultiBuffer, 0, len(h.buf))
|
|
for _, b := range h.buf {
|
|
if len(b) == 0 {
|
|
continue
|
|
}
|
|
|
|
if cap(b) >= buf.Size {
|
|
mb = append(mb, buf.NewExisted(b))
|
|
} else {
|
|
nb := buf.New()
|
|
nb.Extend(int32(len(b)))
|
|
copy(nb.Bytes(), b)
|
|
|
|
mb = append(mb, nb)
|
|
}
|
|
|
|
}
|
|
return mb, nil
|
|
}
|
|
|
|
func (h *MultiHunkReaderWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
|
defer buf.ReleaseMulti(mb)
|
|
if h.done.Done() {
|
|
return io.ErrClosedPipe
|
|
}
|
|
|
|
hunks := make([][]byte, 0, len(mb))
|
|
|
|
for _, b := range mb {
|
|
if b.Len() > 0 {
|
|
hunks = append(hunks, b.Bytes())
|
|
}
|
|
}
|
|
|
|
err := h.hc.Send(&MultiHunk{Data: hunks})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *MultiHunkReaderWriter) Close() error {
|
|
if h.cancel != nil {
|
|
h.cancel()
|
|
}
|
|
if sc, match := h.hc.(StreamCloser); match {
|
|
return sc.CloseSend()
|
|
}
|
|
|
|
return h.done.Close()
|
|
}
|