Merge pull request #561 from fatedier/http

improve http vhost package
pull/564/head
fatedier 7 years ago committed by GitHub
commit 92046a7ca2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -635,7 +635,7 @@ func (cfg *StcpProxyConf) LoadFromFile(name string, section ini.Section) (err er
if tmpStr == "server" || tmpStr == "visitor" { if tmpStr == "server" || tmpStr == "visitor" {
cfg.Role = tmpStr cfg.Role = tmpStr
} else { } else {
cfg.Role = "server" return fmt.Errorf("Parse conf error: incorrect role [%s]", tmpStr)
} }
cfg.Sk = section["sk"] cfg.Sk = section["sk"]
@ -724,7 +724,7 @@ func (cfg *XtcpProxyConf) LoadFromFile(name string, section ini.Section) (err er
if tmpStr == "server" || tmpStr == "visitor" { if tmpStr == "server" || tmpStr == "visitor" {
cfg.Role = tmpStr cfg.Role = tmpStr
} else { } else {
cfg.Role = "server" return fmt.Errorf("Parse conf error: incorrect role [%s]", tmpStr)
} }
cfg.Sk = section["sk"] cfg.Sk = section["sk"]

@ -181,5 +181,5 @@ type NatHoleResp struct {
} }
type NatHoleSid struct { type NatHoleSid struct {
Sid string `json"sid"` Sid string `json:"sid"`
} }

@ -111,7 +111,7 @@ func (hp *HttpProxy) Handle(conn io.ReadWriteCloser) {
if realConn, ok := conn.(frpNet.Conn); ok { if realConn, ok := conn.(frpNet.Conn); ok {
wrapConn = realConn wrapConn = realConn
} else { } else {
wrapConn = frpNet.WrapReadWriteCloserToConn(conn) wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn)
} }
sc, rd := frpNet.NewShareConn(wrapConn) sc, rd := frpNet.NewShareConn(wrapConn)

@ -50,7 +50,7 @@ func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser) {
if realConn, ok := conn.(frpNet.Conn); ok { if realConn, ok := conn.(frpNet.Conn); ok {
wrapConn = realConn wrapConn = realConn
} else { } else {
wrapConn = frpNet.WrapReadWriteCloserToConn(conn) wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn)
} }
sp.Server.ServeConn(wrapConn) sp.Server.ServeConn(wrapConn)

@ -146,7 +146,7 @@ func (vm *VisitorManager) NewConn(name string, conn frpNet.Conn, timestamp int64
if useCompression { if useCompression {
rwc = frpIo.WithCompression(rwc) rwc = frpIo.WithCompression(rwc)
} }
err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc)) err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc, conn))
} else { } else {
err = fmt.Errorf("custom listener for [%s] doesn't exist", name) err = fmt.Errorf("custom listener for [%s] doesn't exist", name)
return return

@ -189,13 +189,16 @@ func (pxy *TcpProxy) Close() {
type HttpProxy struct { type HttpProxy struct {
BaseProxy BaseProxy
cfg *config.HttpProxyConf cfg *config.HttpProxyConf
closeFuncs []func()
} }
func (pxy *HttpProxy) Run() (err error) { func (pxy *HttpProxy) Run() (err error) {
routeConfig := &vhost.VhostRouteConfig{ routeConfig := vhost.VhostRouteConfig{
RewriteHost: pxy.cfg.HostHeaderRewrite, RewriteHost: pxy.cfg.HostHeaderRewrite,
Username: pxy.cfg.HttpUser, Username: pxy.cfg.HttpUser,
Password: pxy.cfg.HttpPwd, Password: pxy.cfg.HttpPwd,
CreateConnFn: pxy.GetRealConn,
} }
locations := pxy.cfg.Locations locations := pxy.cfg.Locations
@ -206,13 +209,16 @@ func (pxy *HttpProxy) Run() (err error) {
routeConfig.Domain = domain routeConfig.Domain = domain
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig) err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
if err != nil { if err != nil {
return err return err
} }
l.AddLogPrefix(pxy.name) tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation)
})
pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
pxy.listeners = append(pxy.listeners, l)
} }
} }
@ -220,17 +226,18 @@ func (pxy *HttpProxy) Run() (err error) {
routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig) err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
if err != nil { if err != nil {
return err return err
} }
l.AddLogPrefix(pxy.name) tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location
pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation)
})
pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
pxy.listeners = append(pxy.listeners, l)
} }
} }
pxy.startListenHandler(pxy, HandleUserTcpConnection)
return return
} }
@ -238,8 +245,33 @@ func (pxy *HttpProxy) GetConf() config.ProxyConf {
return pxy.cfg return pxy.cfg
} }
func (pxy *HttpProxy) GetRealConn() (workConn frpNet.Conn, err error) {
tmpConn, errRet := pxy.GetWorkConnFromPool()
if errRet != nil {
err = errRet
return
}
var rwc io.ReadWriteCloser = tmpConn
if pxy.cfg.UseEncryption {
rwc, err = frpIo.WithEncryption(rwc, []byte(config.ServerCommonCfg.PrivilegeToken))
if err != nil {
pxy.Error("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.UseCompression {
rwc = frpIo.WithCompression(rwc)
}
workConn = frpNet.WrapReadWriteCloserToConn(rwc, tmpConn)
return
}
func (pxy *HttpProxy) Close() { func (pxy *HttpProxy) Close() {
pxy.BaseProxy.Close() pxy.BaseProxy.Close()
for _, closeFn := range pxy.closeFuncs {
closeFn()
}
} }
type HttpsProxy struct { type HttpsProxy struct {

@ -16,6 +16,8 @@ package server
import ( import (
"fmt" "fmt"
"net"
"net/http"
"time" "time"
"github.com/fatedier/frp/assets" "github.com/fatedier/frp/assets"
@ -44,12 +46,11 @@ type Service struct {
// Accept connections using kcp. // Accept connections using kcp.
kcpListener frpNet.Listener kcpListener frpNet.Listener
// For http proxies, route requests to different clients by hostname and other infomation.
VhostHttpMuxer *vhost.HttpMuxer
// For https proxies, route requests to different clients by hostname and other infomation. // For https proxies, route requests to different clients by hostname and other infomation.
VhostHttpsMuxer *vhost.HttpsMuxer VhostHttpsMuxer *vhost.HttpsMuxer
httpReverseProxy *vhost.HttpReverseProxy
// Manage all controllers. // Manage all controllers.
ctlManager *ControlManager ctlManager *ControlManager
@ -98,17 +99,21 @@ func NewService() (svr *Service, err error) {
// Create http vhost muxer. // Create http vhost muxer.
if cfg.VhostHttpPort > 0 { if cfg.VhostHttpPort > 0 {
var l frpNet.Listener rp := vhost.NewHttpReverseProxy()
l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpPort) svr.httpReverseProxy = rp
if err != nil {
err = fmt.Errorf("Create vhost http listener error, %v", err) address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
return server := &http.Server{
Addr: address,
Handler: rp,
} }
svr.VhostHttpMuxer, err = vhost.NewHttpMuxer(l, 30*time.Second) var l net.Listener
l, err = net.Listen("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("Create vhost httpMuxer error, %v", err) err = fmt.Errorf("Create vhost http listener error, %v", err)
return return
} }
go server.Serve(l)
log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
} }

@ -49,32 +49,50 @@ func WrapConn(c net.Conn) Conn {
type WrapReadWriteCloserConn struct { type WrapReadWriteCloserConn struct {
io.ReadWriteCloser io.ReadWriteCloser
log.Logger log.Logger
underConn net.Conn
} }
func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser) Conn { func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser, underConn net.Conn) Conn {
return &WrapReadWriteCloserConn{ return &WrapReadWriteCloserConn{
ReadWriteCloser: rwc, ReadWriteCloser: rwc,
Logger: log.NewPrefixLogger(""), Logger: log.NewPrefixLogger(""),
underConn: underConn,
} }
} }
func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr { func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr {
if conn.underConn != nil {
return conn.underConn.LocalAddr()
}
return (*net.TCPAddr)(nil) return (*net.TCPAddr)(nil)
} }
func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr { func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr {
if conn.underConn != nil {
return conn.underConn.RemoteAddr()
}
return (*net.TCPAddr)(nil) return (*net.TCPAddr)(nil)
} }
func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error { func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error {
if conn.underConn != nil {
return conn.underConn.SetDeadline(t)
}
return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
} }
func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error { func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error {
if conn.underConn != nil {
return conn.underConn.SetReadDeadline(t)
}
return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
} }
func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error { func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error {
if conn.underConn != nil {
return conn.underConn.SetWriteDeadline(t)
}
return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
} }

@ -19,7 +19,7 @@ import (
"strings" "strings"
) )
var version string = "0.14.0" var version string = "0.14.1"
func Full() string { func Full() string {
return version return version

@ -0,0 +1,186 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vhost
import (
"bytes"
"context"
"errors"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
frpLog "github.com/fatedier/frp/utils/log"
"github.com/fatedier/frp/utils/pool"
)
var (
responseHeaderTimeout = time.Duration(30) * time.Second
ErrRouterConfigConflict = errors.New("router config conflict")
ErrNoDomain = errors.New("no such domain")
)
func getHostFromAddr(addr string) (host string) {
strs := strings.Split(addr, ":")
if len(strs) > 1 {
host = strs[0]
} else {
host = addr
}
return
}
type HttpReverseProxy struct {
proxy *ReverseProxy
tr *http.Transport
vhostRouter *VhostRouters
cfgMu sync.RWMutex
}
func NewHttpReverseProxy() *HttpReverseProxy {
rp := &HttpReverseProxy{
vhostRouter: NewVhostRouters(),
}
proxy := &ReverseProxy{
Director: func(req *http.Request) {
req.URL.Scheme = "http"
url := req.Context().Value("url").(string)
host := getHostFromAddr(req.Context().Value("host").(string))
host = rp.GetRealHost(host, url)
if host != "" {
req.Host = host
}
req.URL.Host = req.Host
},
Transport: &http.Transport{
ResponseHeaderTimeout: responseHeaderTimeout,
DisableKeepAlives: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
url := ctx.Value("url").(string)
host := getHostFromAddr(ctx.Value("host").(string))
return rp.CreateConnection(host, url)
},
},
BufferPool: newWrapPool(),
ErrorLog: log.New(newWrapLogger(), "", 0),
}
rp.proxy = proxy
return rp
}
func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error {
rp.cfgMu.Lock()
defer rp.cfgMu.Unlock()
_, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location)
if ok {
return ErrRouterConfigConflict
} else {
rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
}
return nil
}
func (rp *HttpReverseProxy) UnRegister(domain string, location string) {
rp.cfgMu.Lock()
defer rp.cfgMu.Unlock()
rp.vhostRouter.Del(domain, location)
}
func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host string) {
vr, ok := rp.getVhost(domain, location)
if ok {
host = vr.payload.(*VhostRouteConfig).RewriteHost
}
return
}
func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) {
vr, ok := rp.getVhost(domain, location)
if ok {
fn := vr.payload.(*VhostRouteConfig).CreateConnFn
if fn != nil {
return fn()
}
}
return nil, ErrNoDomain
}
func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool {
vr, ok := rp.getVhost(domain, location)
if ok {
checkUser := vr.payload.(*VhostRouteConfig).Username
checkPasswd := vr.payload.(*VhostRouteConfig).Password
if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) {
return false
}
}
return true
}
func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) {
rp.cfgMu.RLock()
defer rp.cfgMu.RUnlock()
// first we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com
vr, ok = rp.vhostRouter.Get(domain, location)
if ok {
return
}
domainSplit := strings.Split(domain, ".")
if len(domainSplit) < 3 {
return vr, false
}
domainSplit[0] = "*"
domain = strings.Join(domainSplit, ".")
vr, ok = rp.vhostRouter.Get(domain, location)
return
}
func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
domain := getHostFromAddr(req.Host)
location := req.URL.Path
user, passwd, _ := req.BasicAuth()
if !rp.CheckAuth(domain, location, user, passwd) {
rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
rp.proxy.ServeHTTP(rw, req)
}
type wrapPool struct{}
func newWrapPool() *wrapPool { return &wrapPool{} }
func (p *wrapPool) Get() []byte { return pool.GetBuf(32 * 1024) }
func (p *wrapPool) Put(buf []byte) { pool.PutBuf(buf) }
type wrapLogger struct{}
func newWrapLogger() *wrapLogger { return &wrapLogger{} }
func (l *wrapLogger) Write(p []byte) (n int, err error) {
frpLog.Warn("%s", string(bytes.TrimRight(p, "\n")))
return len(p), nil
}

@ -0,0 +1,370 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP reverse proxy handler
package vhost
import (
"context"
"io"
"log"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine.
var onExitFlushLoop func()
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
// Director must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
Director func(*http.Request)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body.
// If zero, no periodic flushing is done.
FlushInterval time.Duration
// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
// If nil, logging goes to os.Stderr via the log package's
// standard logger.
ErrorLog *log.Logger
// BufferPool optionally specifies a buffer pool to
// get byte slices for use by io.CopyBuffer when
// copying HTTP response bodies.
BufferPool BufferPool
// ModifyResponse is an optional function that
// modifies the Response from the backend.
// If it returns an error, the proxy returns a StatusBadGateway error.
ModifyResponse func(*http.Response) error
}
// A BufferPool is an interface for getting and returning temporary
// byte slices for use by io.CopyBuffer.
type BufferPool interface {
Get() []byte
Put([]byte)
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// NewSingleHostReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
// NewSingleHostReverseProxy does not rewrite the Host header.
// To rewrite Host headers, use ReverseProxy directly with a custom
// Director policy.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
}
return &ReverseProxy{Director: director}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func cloneHeader(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
ctx := req.Context()
if cn, ok := rw.(http.CloseNotifier); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
notifyChan := cn.CloseNotify()
go func() {
select {
case <-notifyChan:
cancel()
case <-ctx.Done():
}
}()
}
outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
outreq.Header = cloneHeader(req.Header)
// Modify for frp
outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path))
outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host))
p.Director(outreq)
outreq.Close = false
// Remove hop-by-hop headers listed in the "Connection" header.
// See RFC 2616, section 14.10.
if c := outreq.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
outreq.Header.Del(f)
}
}
}
// Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us.
for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" {
outreq.Header.Del(h)
}
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
outreq.Header.Set("X-Forwarded-For", clientIP)
}
res, err := transport.RoundTrip(outreq)
if err != nil {
p.logf("http: proxy error: %v", err)
rw.WriteHeader(http.StatusNotFound)
rw.Write([]byte(NotFound))
return
}
// Remove hop-by-hop headers listed in the
// "Connection" header of the response.
if c := res.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
res.Header.Del(f)
}
}
}
for _, h := range hopHeaders {
res.Header.Del(h)
}
if p.ModifyResponse != nil {
if err := p.ModifyResponse(res); err != nil {
p.logf("http: proxy error: %v", err)
rw.WriteHeader(http.StatusBadGateway)
return
}
}
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
announcedTrailers := len(res.Trailer)
if announcedTrailers > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
}
p.copyResponse(rw, res.Body)
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
return
}
for k, vv := range res.Trailer {
k = http.TrailerPrefix + k
for _, v := range vv {
rw.Header().Add(k, v)
}
}
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
}
var buf []byte
if p.BufferPool != nil {
buf = p.BufferPool.Get()
}
p.copyBuffer(dst, src, buf)
if p.BufferPool != nil {
p.BufferPool.Put(buf)
}
}
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
if len(buf) == 0 {
buf = make([]byte, 32*1024)
}
var written int64
for {
nr, rerr := src.Read(buf)
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
}
if nr > 0 {
nw, werr := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
}
if werr != nil {
return written, werr
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if rerr != nil {
return written, rerr
}
}
}
func (p *ReverseProxy) logf(format string, args ...interface{}) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
type writeFlusher interface {
io.Writer
http.Flusher
}
type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
mu sync.Mutex // protects Write + Flush
done chan bool
}
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
t := time.NewTicker(m.latency)
defer t.Stop()
for {
select {
case <-m.done:
if onExitFlushLoop != nil {
onExitFlushLoop()
}
return
case <-t.C:
m.mu.Lock()
m.dst.Flush()
m.mu.Unlock()
}
}
}
func (m *maxLatencyWriter) stop() { m.done <- true }

@ -14,7 +14,8 @@ type VhostRouters struct {
type VhostRouter struct { type VhostRouter struct {
domain string domain string
location string location string
listener *Listener
payload interface{}
} }
func NewVhostRouters() *VhostRouters { func NewVhostRouters() *VhostRouters {
@ -23,7 +24,7 @@ func NewVhostRouters() *VhostRouters {
} }
} }
func (r *VhostRouters) Add(domain, location string, l *Listener) { func (r *VhostRouters) Add(domain, location string, payload interface{}) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@ -35,7 +36,7 @@ func (r *VhostRouters) Add(domain, location string, l *Listener) {
vr := &VhostRouter{ vr := &VhostRouter{
domain: domain, domain: domain,
location: location, location: location,
listener: l, payload: payload,
} }
vrs = append(vrs, vr) vrs = append(vrs, vr)

@ -50,12 +50,16 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut
return mux, nil return mux, nil
} }
type CreateConnFunc func() (frpNet.Conn, error)
type VhostRouteConfig struct { type VhostRouteConfig struct {
Domain string Domain string
Location string Location string
RewriteHost string RewriteHost string
Username string Username string
Password string Password string
CreateConnFn CreateConnFunc
} }
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil
@ -91,7 +95,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
// if not exist, then check the wildcard_domain such as *.example.com // if not exist, then check the wildcard_domain such as *.example.com
vr, found := v.registryRouter.Get(name, path) vr, found := v.registryRouter.Get(name, path)
if found { if found {
return vr.listener, true return vr.payload.(*Listener), true
} }
domainSplit := strings.Split(name, ".") domainSplit := strings.Split(name, ".")
@ -106,7 +110,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
return return
} }
return vr.listener, true return vr.payload.(*Listener), true
} }
func (v *VhostMuxer) run() { func (v *VhostMuxer) run() {

Loading…
Cancel
Save