From 46358d466d0505728c610a6c064ca9239d42b39b Mon Sep 17 00:00:00 2001
From: fatedier <fatedier@gmail.com>
Date: Wed, 13 Dec 2017 04:28:58 +0800
Subject: [PATCH] support encryption and compression in new http reverser proxy

---
 models/plugin/http_proxy.go |  2 +-
 models/plugin/socks5.go     |  2 +-
 server/manager.go           |  2 +-
 server/proxy.go             | 26 ++++++++++++++++++++++++--
 utils/net/conn.go           | 20 +++++++++++++++++++-
 utils/vhost/newhttp.go      |  5 +++--
 6 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/models/plugin/http_proxy.go b/models/plugin/http_proxy.go
index aaee5a16..f5fed6cb 100644
--- a/models/plugin/http_proxy.go
+++ b/models/plugin/http_proxy.go
@@ -111,7 +111,7 @@ func (hp *HttpProxy) Handle(conn io.ReadWriteCloser) {
 	if realConn, ok := conn.(frpNet.Conn); ok {
 		wrapConn = realConn
 	} else {
-		wrapConn = frpNet.WrapReadWriteCloserToConn(conn)
+		wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn)
 	}
 
 	sc, rd := frpNet.NewShareConn(wrapConn)
diff --git a/models/plugin/socks5.go b/models/plugin/socks5.go
index d3b82e12..b0f1bb24 100644
--- a/models/plugin/socks5.go
+++ b/models/plugin/socks5.go
@@ -50,7 +50,7 @@ func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser) {
 	if realConn, ok := conn.(frpNet.Conn); ok {
 		wrapConn = realConn
 	} else {
-		wrapConn = frpNet.WrapReadWriteCloserToConn(conn)
+		wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn)
 	}
 
 	sp.Server.ServeConn(wrapConn)
diff --git a/server/manager.go b/server/manager.go
index c023d187..ebc0928f 100644
--- a/server/manager.go
+++ b/server/manager.go
@@ -146,7 +146,7 @@ func (vm *VisitorManager) NewConn(name string, conn frpNet.Conn, timestamp int64
 		if useCompression {
 			rwc = frpIo.WithCompression(rwc)
 		}
-		err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc))
+		err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc, conn))
 	} else {
 		err = fmt.Errorf("custom listener for [%s] doesn't exist", name)
 		return
diff --git a/server/proxy.go b/server/proxy.go
index ed51a602..bd6234db 100644
--- a/server/proxy.go
+++ b/server/proxy.go
@@ -208,7 +208,7 @@ func (pxy *HttpProxy) Run() (err error) {
 		routeConfig.Domain = domain
 		for _, location := range locations {
 			routeConfig.Location = location
-			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetWorkConnFromPool)
+			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetRealConn)
 			if err != nil {
 				return err
 			}
@@ -225,7 +225,7 @@ func (pxy *HttpProxy) Run() (err error) {
 		routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
 		for _, location := range locations {
 			routeConfig.Location = location
-			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetWorkConnFromPool)
+			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetRealConn)
 			if err != nil {
 				return err
 			}
@@ -244,6 +244,28 @@ func (pxy *HttpProxy) GetConf() config.ProxyConf {
 	return pxy.cfg
 }
 
+func (pxy *HttpProxy) GetRealConn() (workConn frpNet.Conn, err error) {
+	tmpConn, errRet := pxy.GetWorkConnFromPool()
+	if errRet != nil {
+		err = errRet
+		return
+	}
+
+	var rwc io.ReadWriteCloser = tmpConn
+	if pxy.cfg.UseEncryption {
+		rwc, err = frpIo.WithEncryption(rwc, []byte(config.ServerCommonCfg.PrivilegeToken))
+		if err != nil {
+			pxy.Error("create encryption stream error: %v", err)
+			return
+		}
+	}
+	if pxy.cfg.UseCompression {
+		rwc = frpIo.WithCompression(rwc)
+	}
+	workConn = frpNet.WrapReadWriteCloserToConn(rwc, tmpConn)
+	return
+}
+
 func (pxy *HttpProxy) Close() {
 	pxy.BaseProxy.Close()
 	for _, closeFn := range pxy.closeFuncs {
diff --git a/utils/net/conn.go b/utils/net/conn.go
index 392fb98f..c1f6f462 100644
--- a/utils/net/conn.go
+++ b/utils/net/conn.go
@@ -49,32 +49,50 @@ func WrapConn(c net.Conn) Conn {
 type WrapReadWriteCloserConn struct {
 	io.ReadWriteCloser
 	log.Logger
+
+	underConn net.Conn
 }
 
-func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser) Conn {
+func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser, underConn net.Conn) Conn {
 	return &WrapReadWriteCloserConn{
 		ReadWriteCloser: rwc,
 		Logger:          log.NewPrefixLogger(""),
+		underConn:       underConn,
 	}
 }
 
 func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr {
+	if conn.underConn != nil {
+		return conn.underConn.LocalAddr()
+	}
 	return (*net.TCPAddr)(nil)
 }
 
 func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr {
+	if conn.underConn != nil {
+		return conn.underConn.RemoteAddr()
+	}
 	return (*net.TCPAddr)(nil)
 }
 
 func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error {
+	if conn.underConn != nil {
+		return conn.underConn.SetDeadline(t)
+	}
 	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 }
 
 func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error {
+	if conn.underConn != nil {
+		return conn.underConn.SetReadDeadline(t)
+	}
 	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 }
 
 func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error {
+	if conn.underConn != nil {
+		return conn.underConn.SetWriteDeadline(t)
+	}
 	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 }
 
diff --git a/utils/vhost/newhttp.go b/utils/vhost/newhttp.go
index c81f4441..55ba4199 100644
--- a/utils/vhost/newhttp.go
+++ b/utils/vhost/newhttp.go
@@ -15,6 +15,7 @@
 package vhost
 
 import (
+	"bytes"
 	"context"
 	"errors"
 	"log"
@@ -74,8 +75,8 @@ func NewHttpReverseProxy() *HttpReverseProxy {
 			host = rp.GetRealHost(host, url)
 			if host != "" {
 				req.Host = host
-				req.URL.Host = req.Host
 			}
+			req.URL.Host = req.Host
 		},
 		Transport: &http.Transport{
 			ResponseHeaderTimeout: responseHeaderTimeout,
@@ -172,6 +173,6 @@ type wrapLogger struct{}
 func newWrapLogger() *wrapLogger { return &wrapLogger{} }
 
 func (l *wrapLogger) Write(p []byte) (n int, err error) {
-	frpLog.Warn("%s", string(p))
+	frpLog.Warn("%s", string(bytes.TrimRight(p, "\n")))
 	return len(p), nil
 }