diff --git a/utils/vhost/http.go b/utils/vhost/http.go index 2e8208c..5ecce03 100644 --- a/utils/vhost/http.go +++ b/utils/vhost/http.go @@ -57,30 +57,31 @@ func GetHttpRequestInfo(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err } func NewHttpMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpMuxer, error) { - mux, err := NewVhostMuxer(listener, GetHttpRequestInfo, HttpAuthFunc, HttpHostNameRewrite, timeout) + mux, err := NewVhostMuxer(listener, GetHttpRequestInfo, HttpAuthFunc, ModifyHttpRequest, timeout) return &HttpMuxer{mux}, err } -func HttpHostNameRewrite(c frpNet.Conn, rewriteHost string) (_ frpNet.Conn, err error) { +func ModifyHttpRequest(c frpNet.Conn, rewriteHost string) (_ frpNet.Conn, err error) { sc, rd := frpNet.NewShareConn(c) var buff []byte - if buff, err = hostNameRewrite(rd, rewriteHost); err != nil { + remoteIP := strings.Split(c.RemoteAddr().String(), ":")[0] + if buff, err = hostNameRewrite(rd, rewriteHost, remoteIP); err != nil { return sc, err } err = sc.WriteBuff(buff) return sc, err } -func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) { +func hostNameRewrite(request io.Reader, rewriteHost string, remoteIP string) (_ []byte, err error) { buf := pool.GetBuf(1024) defer pool.PutBuf(buf) request.Read(buf) - retBuffer, err := parseRequest(buf, rewriteHost) + retBuffer, err := parseRequest(buf, rewriteHost, remoteIP) return retBuffer, err } -func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) { +func parseRequest(org []byte, rewriteHost string, remoteIP string) (ret []byte, err error) { tp := bytes.NewBuffer(org) // First line: GET /index.html HTTP/1.0 var b []byte @@ -106,10 +107,19 @@ func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) { // GET /index.html HTTP/1.1 // Host: www.google.com if req.URL.Host == "" { - changedBuf, err := changeHostName(tp, rewriteHost) + var changedBuf []byte + if rewriteHost != "" { + changedBuf, err = changeHostName(tp, rewriteHost) + } buf := new(bytes.Buffer) buf.Write(b) - buf.Write(changedBuf) + buf.WriteString(fmt.Sprintf("X-Forwarded-For: %s\n", remoteIP)) + buf.WriteString(fmt.Sprintf("X-Real-IP: %s\n", remoteIP)) + if len(changedBuf) == 0 { + tp.WriteTo(buf) + } else { + buf.Write(changedBuf) + } return buf.Bytes(), err } @@ -117,18 +127,21 @@ func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) { // GET http://www.google.com/index.html HTTP/1.1 // Host: doesntmatter // In this case, any Host line is ignored. - hostPort := strings.Split(req.URL.Host, ":") - if len(hostPort) == 1 { - req.URL.Host = rewriteHost - } else if len(hostPort) == 2 { - req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1]) + if rewriteHost != "" { + hostPort := strings.Split(req.URL.Host, ":") + if len(hostPort) == 1 { + req.URL.Host = rewriteHost + } else if len(hostPort) == 2 { + req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1]) + } } firstLine := req.Method + " " + req.URL.String() + " " + req.Proto buf := new(bytes.Buffer) buf.WriteString(firstLine) + buf.WriteString(fmt.Sprintf("X-Forwarded-For: %s\n", remoteIP)) + buf.WriteString(fmt.Sprintf("X-Real-IP: %s\n", remoteIP)) tp.WriteTo(buf) return buf.Bytes(), err - } // parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. @@ -164,7 +177,7 @@ func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error if portPos == -1 { hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost) } else { - hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:]) + hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[j+portPos+2:]) } retBuf.WriteString(hostHeader) peek = peek[i+1:] diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index 21771d0..bb2b4ad 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -182,9 +182,10 @@ func (l *Listener) Accept() (frpNet.Conn, error) { return nil, fmt.Errorf("Listener closed") } - // if rewriteFunc is exist and rewriteHost is set + // if rewriteFunc is exist // rewrite http requests with a modified host header - if l.mux.rewriteFunc != nil && l.rewriteHost != "" { + // if l.rewriteHost is empty, nothing to do + if l.mux.rewriteFunc != nil { sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost) if err != nil { l.Warn("host header rewrite failed: %v", err)