diff --git a/src/frp/cmd/frpc/control.go b/src/frp/cmd/frpc/control.go index 9e8c3e62..ffdfe10d 100644 --- a/src/frp/cmd/frpc/control.go +++ b/src/frp/cmd/frpc/control.go @@ -144,6 +144,8 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) { UseGzip: cli.UseGzip, PrivilegeMode: cli.PrivilegeMode, ProxyType: cli.Type, + LocalIp: cli.LocalIp, + LocalPort: cli.LocalPort, Timestamp: nowTime, } if cli.PrivilegeMode { diff --git a/src/frp/cmd/frps/control.go b/src/frp/cmd/frps/control.go index 61660bfa..700a1466 100644 --- a/src/frp/cmd/frps/control.go +++ b/src/frp/cmd/frps/control.go @@ -276,6 +276,8 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) { // set infomations from frpc s.UseEncryption = req.UseEncryption s.UseGzip = req.UseGzip + s.ClientIp = req.LocalIp + s.ClientPort = req.LocalPort // start proxy and listen for user connections, no block err := s.Start(c) diff --git a/src/frp/models/config/config.go b/src/frp/models/config/config.go index 14200eb4..b18e6282 100644 --- a/src/frp/models/config/config.go +++ b/src/frp/models/config/config.go @@ -22,4 +22,7 @@ type BaseConf struct { UseGzip bool PrivilegeMode bool PrivilegeToken string + ClientIp string + ClientPort int64 + ServerPort int64 } diff --git a/src/frp/models/msg/msg.go b/src/frp/models/msg/msg.go index e89bce1c..4d06b7c1 100644 --- a/src/frp/models/msg/msg.go +++ b/src/frp/models/msg/msg.go @@ -26,6 +26,8 @@ type ControlReq struct { AuthKey string `json:"auth_key"` UseEncryption bool `json:"use_encryption"` UseGzip bool `json:"use_gzip"` + LocalIp string `json:"local_ip"` + LocalPort int64 `json:"local_port"` // configures used if privilege_mode is enabled PrivilegeMode bool `json:"privilege_mode"` diff --git a/src/frp/models/server/server.go b/src/frp/models/server/server.go index e69a2793..87087b49 100644 --- a/src/frp/models/server/server.go +++ b/src/frp/models/server/server.go @@ -64,6 +64,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) { p.BindAddr = BindAddr p.ListenPort = req.RemotePort p.CustomDomains = req.CustomDomains + p.ServerPort = VhostHttpPort return } @@ -113,7 +114,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { p.listeners = append(p.listeners, l) } else if p.Type == "http" { for _, domain := range p.CustomDomains { - l, err := VhostHttpMuxer.Listen(domain) + l, err := VhostHttpMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort) if err != nil { return err } @@ -121,7 +122,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { } } else if p.Type == "https" { for _, domain := range p.CustomDomains { - l, err := VhostHttpsMuxer.Listen(domain) + l, err := VhostHttpsMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort) if err != nil { return err } diff --git a/src/frp/utils/vhost/http.go b/src/frp/utils/vhost/http.go index 0f6aab5b..5e78de68 100644 --- a/src/frp/utils/vhost/http.go +++ b/src/frp/utils/vhost/http.go @@ -16,12 +16,17 @@ package vhost import ( "bufio" + "bytes" + "fmt" + "io" "net" "net/http" + "net/url" "strings" "time" "frp/utils/conn" + "frp/utils/log" ) type HttpMuxer struct { @@ -45,3 +50,112 @@ func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, e mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout) return &HttpMuxer{mux}, err } + +func HostNameRewrite(c *conn.Conn, clientHost string) (_ net.Conn, err error) { + log.Info("HostNameRewrite, clientHost: %s", clientHost) + sc, rd := newShareConn(c.TcpConn) + var buff []byte + if buff, err = hostNameRewrite(rd, clientHost); err != nil { + return sc, err + } + err = sc.WriteBuff(buff) + return sc, err +} + +func hostNameRewrite(request io.Reader, clientHost string) (_ []byte, err error) { + buffer := make([]byte, 1024) + request.Read(buffer) + log.Debug("before hostNameRewrite:\n %s", string(buffer)) + retBuffer, err := parseRequest(buffer, clientHost) + log.Debug("after hostNameRewrite:\n %s", string(retBuffer)) + return retBuffer, err +} + +func parseRequest(org []byte, clientHost string) (ret []byte, err error) { + tp := bytes.NewBuffer(org) + // First line: GET /index.html HTTP/1.0 + var b []byte + if b, err = tp.ReadBytes('\n'); err != nil { + return nil, err + } + req := new(http.Request) + //we invoked ReadRequest in GetHttpHostname before, so we ignore error + req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b)) + rawurl := req.RequestURI + //CONNECT www.google.com:443 HTTP/1.1 + justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") + if justAuthority { + rawurl = "http://" + rawurl + } + req.URL, _ = url.ParseRequestURI(rawurl) + if justAuthority { + // Strip the bogus "http://" back off. + req.URL.Scheme = "" + } + + // RFC2616: first case + // GET /index.html HTTP/1.1 + // Host: www.google.com + if req.URL.Host == "" { + changedBuf, err := changeHostName(tp, clientHost) + buf := new(bytes.Buffer) + buf.Write(b) + buf.Write(changedBuf) + return buf.Bytes(), err + } + + // RFC2616: second case + // GET http://www.google.com/index.html HTTP/1.1 + // Host: doesntmatter + // In this case, any Host line is ignored. + req.URL.Host = clientHost + firstLine := req.Method + " " + req.URL.String() + " " + req.Proto + buf := new(bytes.Buffer) + buf.WriteString(firstLine) + tp.WriteTo(buf) + return buf.Bytes(), err + +} + +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error) { + retBuf := new(bytes.Buffer) + + peek := buff.Bytes() + for len(peek) > 0 { + i := bytes.IndexByte(peek, '\n') + if i < 3 { + // Not present (-1) or found within the next few bytes, + // implying we're at the end ("\r\n\r\n" or "\n\n") + return nil, err + } + kv := peek[:i] + j := bytes.IndexByte(kv, ':') + if j < 0 { + return nil, fmt.Errorf("malformed MIME header line: " + string(kv)) + } + if strings.Contains(strings.ToLower(string(kv[:j])), "host") { + hostHeader := fmt.Sprintf("Host: %s\n", clientHost) + retBuf.WriteString(hostHeader) + peek = peek[i+1:] + break + } else { + retBuf.Write(peek[:i]) + retBuf.WriteByte('\n') + } + + peek = peek[i+1:] + } + retBuf.Write(peek) + return retBuf.Bytes(), err +} diff --git a/src/frp/utils/vhost/vhost.go b/src/frp/utils/vhost/vhost.go index ae672097..11038296 100644 --- a/src/frp/utils/vhost/vhost.go +++ b/src/frp/utils/vhost/vhost.go @@ -34,6 +34,10 @@ type VhostMuxer struct { vhostFunc muxFunc registryMap map[string]*Listener mutex sync.RWMutex + + //build map between custom_domains and client_domain + domainMap map[string]string + domainMutex sync.RWMutex } func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) { @@ -47,7 +51,7 @@ func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Dura return mux, nil } -func (v *VhostMuxer) Listen(name string) (l *Listener, err error) { +func (v *VhostMuxer) Listen(name, proxytype, clientIp string, clientPort, serverPort int64) (l *Listener, err error) { v.mutex.Lock() defer v.mutex.Unlock() if _, exist := v.registryMap[name]; exist { @@ -55,9 +59,13 @@ func (v *VhostMuxer) Listen(name string) (l *Listener, err error) { } l = &Listener{ - name: name, - mux: v, - accept: make(chan *conn.Conn), + name: name, + mux: v, + accept: make(chan *conn.Conn), + proxyType: proxytype, + clientIp: clientIp, + clientPort: clientPort, + serverPort: serverPort, } v.registryMap[name] = l return l, nil @@ -111,9 +119,13 @@ func (v *VhostMuxer) handle(c *conn.Conn) { } type Listener struct { - name string - mux *VhostMuxer // for closing VhostMuxer - accept chan *conn.Conn + name string + mux *VhostMuxer // for closing VhostMuxer + accept chan *conn.Conn + proxyType string //suppor http host rewrite + clientIp string + clientPort int64 + serverPort int64 } func (l *Listener) Accept() (*conn.Conn, error) { @@ -121,6 +133,20 @@ func (l *Listener) Accept() (*conn.Conn, error) { if !ok { return nil, fmt.Errorf("Listener closed") } + if net.ParseIP(l.clientIp) == nil && l.proxyType == "http" { + if (l.name != l.clientIp) || (l.serverPort != l.clientPort) { + clientHost := l.clientIp + if l.clientPort != 80 { + strPort := fmt.Sprintf(":%d", l.clientPort) + clientHost += strPort + } + retConn, err := HostNameRewrite(conn, clientHost) + if err != nil { + return nil, fmt.Errorf("http host rewrite failed") + } + conn.SetTcpConn(retConn) + } + } return conn, nil } @@ -166,3 +192,9 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) { sc.Unlock() return } + +func (sc *sharedConn) WriteBuff(buffer []byte) (err error) { + sc.buff.Reset() + _, err = sc.buff.Write(buffer) + return err +}