diff --git a/conf/frpc.ini b/conf/frpc.ini index d70760d9..4c6167b8 100644 --- a/conf/frpc.ini +++ b/conf/frpc.ini @@ -55,3 +55,4 @@ local_ip = 127.0.0.1 local_port = 80 use_gzip = true custom_domains = web03.yourdomain.com +host_header_rewrite = example.com diff --git a/src/frp/cmd/frpc/control.go b/src/frp/cmd/frpc/control.go index 2b2a7eec..f0f11cb4 100644 --- a/src/frp/cmd/frpc/control.go +++ b/src/frp/cmd/frpc/control.go @@ -138,14 +138,14 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) { nowTime := time.Now().Unix() req := &msg.ControlReq{ - Type: consts.NewCtlConn, - ProxyName: cli.Name, - UseEncryption: cli.UseEncryption, - UseGzip: cli.UseGzip, - PoolCount: cli.PoolCount, - PrivilegeMode: cli.PrivilegeMode, - ProxyType: cli.Type, - Timestamp: nowTime, + Type: consts.NewCtlConn, + ProxyName: cli.Name, + UseEncryption: cli.UseEncryption, + UseGzip: cli.UseGzip, + PrivilegeMode: cli.PrivilegeMode, + ProxyType: cli.Type, + HostHeaderRewrite: cli.HostHeaderRewrite, + Timestamp: nowTime, } if cli.PrivilegeMode { privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime)) diff --git a/src/frp/cmd/frps/control.go b/src/frp/cmd/frps/control.go index f51b905a..5fc0dbf2 100644 --- a/src/frp/cmd/frps/control.go +++ b/src/frp/cmd/frps/control.go @@ -276,6 +276,7 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) { // set infomations from frpc s.UseEncryption = req.UseEncryption s.UseGzip = req.UseGzip + s.HostHeaderRewrite = req.HostHeaderRewrite if req.PoolCount > server.MaxPoolCount { s.PoolCount = server.MaxPoolCount } else if req.PoolCount < 0 { diff --git a/src/frp/models/client/config.go b/src/frp/models/client/config.go index 7269227d..3068ef17 100644 --- a/src/frp/models/client/config.go +++ b/src/frp/models/client/config.go @@ -140,6 +140,14 @@ func LoadConf(confFile string) (err error) { proxyClient.UseGzip = true } + if proxyClient.Type == "http" { + // host_header_rewrite + tmpStr, ok = section["host_header_rewrite"] + if ok { + proxyClient.HostHeaderRewrite = tmpStr + } + } + // privilege_mode proxyClient.PrivilegeMode = false tmpStr, ok = section["privilege_mode"] @@ -178,6 +186,7 @@ func LoadConf(confFile string) (err error) { return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name) } } else if proxyClient.Type == "http" { + // custom_domains domainStr, ok := section["custom_domains"] if ok { proxyClient.CustomDomains = strings.Split(domainStr, ",") @@ -191,6 +200,7 @@ func LoadConf(confFile string) (err error) { return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name) } } else if proxyClient.Type == "https" { + // custom_domains domainStr, ok := section["custom_domains"] if ok { proxyClient.CustomDomains = strings.Split(domainStr, ",") diff --git a/src/frp/models/config/config.go b/src/frp/models/config/config.go index f7cc5098..325dcb9b 100644 --- a/src/frp/models/config/config.go +++ b/src/frp/models/config/config.go @@ -15,12 +15,13 @@ package config type BaseConf struct { - Name string - AuthToken string - Type string - UseEncryption bool - UseGzip bool - PrivilegeMode bool - PrivilegeToken string - PoolCount int64 + Name string + AuthToken string + Type string + UseEncryption bool + UseGzip bool + PrivilegeMode bool + PrivilegeToken string + PoolCount int64 + HostHeaderRewrite string } diff --git a/src/frp/models/msg/msg.go b/src/frp/models/msg/msg.go index 55590dda..253511af 100644 --- a/src/frp/models/msg/msg.go +++ b/src/frp/models/msg/msg.go @@ -29,12 +29,13 @@ type ControlReq struct { PoolCount int64 `json:"pool_count"` // configures used if privilege_mode is enabled - PrivilegeMode bool `json:"privilege_mode"` - PrivilegeKey string `json:"privilege_key"` - ProxyType string `json:"proxy_type"` - RemotePort int64 `json:"remote_port"` - CustomDomains []string `json:"custom_domains, omitempty"` - Timestamp int64 `json:"timestamp"` + PrivilegeMode bool `json:"privilege_mode"` + PrivilegeKey string `json:"privilege_key"` + ProxyType string `json:"proxy_type"` + RemotePort int64 `json:"remote_port"` + CustomDomains []string `json:"custom_domains, omitempty"` + HostHeaderRewrite string `json:"host_header_rewrite"` + Timestamp int64 `json:"timestamp"` } type ControlRes struct { diff --git a/src/frp/models/msg/process.go b/src/frp/models/msg/process.go index 4c7783bd..43e34ab6 100644 --- a/src/frp/models/msg/process.go +++ b/src/frp/models/msg/process.go @@ -15,12 +15,10 @@ package msg import ( - "bufio" "bytes" "encoding/binary" "fmt" "io" - "net" "sync" "frp/models/config" @@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo defer wait.Done() // we don't care about errors here - pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord) + pipeEncrypt(from, to, conf, needRecord) } decryptPipe := func(to *conn.Conn, from *conn.Conn) { @@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo defer wait.Done() // we don't care about errors here - pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord) + pipeDecrypt(to, from, conf, needRecord) } wait.Add(2) @@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) { } // decrypt msg from reader, then write into writer -func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) buf := make([]byte, 5*1024+4) var left, res []byte - var cnt int + var cnt int = -1 // record var flowBytes int64 = 0 @@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) }() } - nreader := bufio.NewReader(r) for { // there may be more than 1 package in variable // and we read more bytes if unpkgMsg returns an error var newBuf []byte if cnt < 0 { - n, err := nreader.Read(buf) + n, err := r.Read(buf) if err != nil { return err } @@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } } - _, err = w.Write(res) + _, err = w.WriteBytes(res) if err != nil { return err } @@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } // recvive msg from reader, then encrypt msg into writer -func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) }() } - nreader := bufio.NewReader(r) buf := make([]byte, 5*1024) for { - n, err := nreader.Read(buf) + n, err := r.Read(buf) if err != nil { return err } @@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } res = pkgMsg(res) - _, err = w.Write(res) + _, err = w.WriteBytes(res) if err != nil { return err } diff --git a/src/frp/models/server/server.go b/src/frp/models/server/server.go index 99aff28e..4a84dca5 100644 --- a/src/frp/models/server/server.go +++ b/src/frp/models/server/server.go @@ -65,6 +65,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) { p.BindAddr = BindAddr p.ListenPort = req.RemotePort p.CustomDomains = req.CustomDomains + p.HostHeaderRewrite = req.HostHeaderRewrite return } @@ -81,7 +82,7 @@ func (p *ProxyServer) Init() { func (p *ProxyServer) Compare(p2 *ProxyServer) bool { if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type || - p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort { + p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite { return false } if len(p.CustomDomains) != len(p2.CustomDomains) { @@ -115,7 +116,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { p.listeners = append(p.listeners, l) } else if p.Type == "http" { for _, domain := range p.CustomDomains { - l, err := VhostHttpMuxer.Listen(domain) + l, err := VhostHttpMuxer.Listen(domain, p.HostHeaderRewrite) if err != nil { return err } @@ -123,7 +124,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { } } else if p.Type == "https" { for _, domain := range p.CustomDomains { - l, err := VhostHttpsMuxer.Listen(domain) + l, err := VhostHttpsMuxer.Listen(domain, p.HostHeaderRewrite) if err != nil { return err } @@ -160,14 +161,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { return } - // start another goroutine for join two connections between frpc and user - go func() { + go func(userConn *conn.Conn) { workConn, err := p.getWorkConn() if err != nil { return } - userConn := c // message will be transferred to another without modifying // l means local, r means remote log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), @@ -176,7 +175,8 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { metric.OpenConnection(p.Name) needRecord := true go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) - }() + metric.OpenConnection(p.Name) + }(c) } }(listener) } diff --git a/src/frp/utils/conn/conn.go b/src/frp/utils/conn/conn.go index 0bfe9648..7375fd54 100644 --- a/src/frp/utils/conn/conn.go +++ b/src/frp/utils/conn/conn.go @@ -117,6 +117,16 @@ func ConnectServer(host string, port int64) (c *Conn, err error) { return c, nil } +// if the tcpConn is different with c.TcpConn +// you should call c.Close() first +func (c *Conn) SetTcpConn(tcpConn net.Conn) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.TcpConn = tcpConn + c.closeFlag = false + c.Reader = bufio.NewReader(c.TcpConn) +} + func (c *Conn) GetRemoteAddr() (addr string) { return c.TcpConn.RemoteAddr().String() } @@ -125,6 +135,11 @@ func (c *Conn) GetLocalAddr() (addr string) { return c.TcpConn.LocalAddr().String() } +func (c *Conn) Read(p []byte) (n int, err error) { + n, err = c.Reader.Read(p) + return +} + func (c *Conn) ReadLine() (buff string, err error) { buff, err = c.Reader.ReadString('\n') if err != nil { @@ -138,10 +153,14 @@ func (c *Conn) ReadLine() (buff string, err error) { return buff, err } +func (c *Conn) WriteBytes(content []byte) (n int, err error) { + n, err = c.TcpConn.Write(content) + return +} + func (c *Conn) Write(content string) (err error) { _, err = c.TcpConn.Write([]byte(content)) return err - } func (c *Conn) SetDeadline(t time.Time) error { diff --git a/src/frp/utils/vhost/http.go b/src/frp/utils/vhost/http.go index 0f6aab5b..4bc720cc 100644 --- a/src/frp/utils/vhost/http.go +++ b/src/frp/utils/vhost/http.go @@ -16,8 +16,12 @@ package vhost import ( "bufio" + "bytes" + "fmt" + "io" "net" "net/http" + "net/url" "strings" "time" @@ -42,6 +46,123 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) { } func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) { - mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout) + mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout) return &HttpMuxer{mux}, err } + +func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) { + sc, rd := newShareConn(c.TcpConn) + var buff []byte + if buff, err = hostNameRewrite(rd, rewriteHost); err != nil { + return sc, err + } + err = sc.WriteBuff(buff) + return sc, err +} + +func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) { + buffer := make([]byte, 1024) + request.Read(buffer) + retBuffer, err := parseRequest(buffer, rewriteHost) + return retBuffer, err +} + +func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) { + tp := bytes.NewBuffer(org) + // First line: GET /index.html HTTP/1.0 + var b []byte + if b, err = tp.ReadBytes('\n'); err != nil { + return nil, err + } + req := new(http.Request) + // we invoked ReadRequest in GetHttpHostname before, so we ignore error + req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b)) + rawurl := req.RequestURI + // CONNECT www.google.com:443 HTTP/1.1 + justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") + if justAuthority { + rawurl = "http://" + rawurl + } + req.URL, _ = url.ParseRequestURI(rawurl) + if justAuthority { + // Strip the bogus "http://" back off. + req.URL.Scheme = "" + } + + // RFC2616: first case + // GET /index.html HTTP/1.1 + // Host: www.google.com + if req.URL.Host == "" { + changedBuf, err := changeHostName(tp, rewriteHost) + buf := new(bytes.Buffer) + buf.Write(b) + buf.Write(changedBuf) + return buf.Bytes(), err + } + + // RFC2616: second case + // GET http://www.google.com/index.html HTTP/1.1 + // Host: doesntmatter + // In this case, any Host line is ignored. + hostPort := strings.Split(req.URL.Host, ":") + if len(hostPort) == 1 { + req.URL.Host = rewriteHost + } else if len(hostPort) == 2 { + req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1]) + } + firstLine := req.Method + " " + req.URL.String() + " " + req.Proto + buf := new(bytes.Buffer) + buf.WriteString(firstLine) + tp.WriteTo(buf) + return buf.Bytes(), err + +} + +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) { + retBuf := new(bytes.Buffer) + + peek := buff.Bytes() + for len(peek) > 0 { + i := bytes.IndexByte(peek, '\n') + if i < 3 { + // Not present (-1) or found within the next few bytes, + // implying we're at the end ("\r\n\r\n" or "\n\n") + return nil, err + } + kv := peek[:i] + j := bytes.IndexByte(kv, ':') + if j < 0 { + return nil, fmt.Errorf("malformed MIME header line: " + string(kv)) + } + if strings.Contains(strings.ToLower(string(kv[:j])), "host") { + var hostHeader string + portPos := bytes.IndexByte(kv[j+1:], ':') + if portPos == -1 { + hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost) + } else { + hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:]) + } + retBuf.WriteString(hostHeader) + peek = peek[i+1:] + break + } else { + retBuf.Write(peek[:i]) + retBuf.WriteByte('\n') + } + + peek = peek[i+1:] + } + retBuf.Write(peek) + return retBuf.Bytes(), err +} diff --git a/src/frp/utils/vhost/https.go b/src/frp/utils/vhost/https.go index 2fd61c62..eedfab37 100644 --- a/src/frp/utils/vhost/https.go +++ b/src/frp/utils/vhost/https.go @@ -47,7 +47,7 @@ type HttpsMuxer struct { } func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) { - mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout) + mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout) return &HttpsMuxer{mux}, err } diff --git a/src/frp/utils/vhost/vhost.go b/src/frp/utils/vhost/vhost.go index ecf080d1..18c6d5dd 100644 --- a/src/frp/utils/vhost/vhost.go +++ b/src/frp/utils/vhost/vhost.go @@ -27,37 +27,42 @@ import ( ) type muxFunc func(*conn.Conn) (net.Conn, string, error) +type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error) type VhostMuxer struct { listener *conn.Listener timeout time.Duration vhostFunc muxFunc + rewriteFunc hostRewriteFunc registryMap map[string]*Listener mutex sync.RWMutex } -func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) { +func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { mux = &VhostMuxer{ listener: listener, timeout: timeout, vhostFunc: vhostFunc, + rewriteFunc: rewriteFunc, registryMap: make(map[string]*Listener), } go mux.run() return mux, nil } -func (v *VhostMuxer) Listen(name string) (l *Listener, err error) { +// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost +func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) { v.mutex.Lock() defer v.mutex.Unlock() if _, exist := v.registryMap[name]; exist { - return nil, fmt.Errorf("name %s is already bound", name) + return nil, fmt.Errorf("domain name %s is already bound", name) } l = &Listener{ - name: name, - mux: v, - accept: make(chan *conn.Conn), + name: name, + rewriteHost: rewriteHost, + mux: v, + accept: make(chan *conn.Conn), } v.registryMap[name] = l return l, nil @@ -105,15 +110,16 @@ func (v *VhostMuxer) handle(c *conn.Conn) { if err = sConn.SetDeadline(time.Time{}); err != nil { return } - c.TcpConn = sConn + c.SetTcpConn(sConn) l.accept <- c } type Listener struct { - name string - mux *VhostMuxer // for closing VhostMuxer - accept chan *conn.Conn + name string + rewriteHost string + mux *VhostMuxer // for closing VhostMuxer + accept chan *conn.Conn } func (l *Listener) Accept() (*conn.Conn, error) { @@ -121,6 +127,17 @@ func (l *Listener) Accept() (*conn.Conn, error) { if !ok { return nil, fmt.Errorf("Listener closed") } + + // if rewriteFunc is exist and rewriteHost is set + // rewrite http requests with a modified host header + if l.mux.rewriteFunc != nil && l.rewriteHost != "" { + fmt.Printf("host rewrite: %s\n", l.rewriteHost) + sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost) + if err != nil { + return nil, fmt.Errorf("http host header rewrite failed") + } + conn.SetTcpConn(sConn) + } return conn, nil } @@ -140,6 +157,7 @@ type sharedConn struct { buff *bytes.Buffer } +// the bytes you read in io.Reader, will be reserved in sharedConn func newShareConn(conn net.Conn) (*sharedConn, io.Reader) { sc := &sharedConn{ Conn: conn, @@ -166,3 +184,9 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) { sc.Unlock() return } + +func (sc *sharedConn) WriteBuff(buffer []byte) (err error) { + sc.buff.Reset() + _, err = sc.buff.Write(buffer) + return err +}