mirror of https://github.com/fatedier/frp
Fix conflicts in fatedier/connection_pool with dev
Conflicts: src/frp/cmd/frpc/control.go src/frp/cmd/frps/control.go src/frp/models/config/config.go src/frp/models/server/server.gopull/55/head
commit
fd3c97a0e9
|
@ -55,3 +55,4 @@ local_ip = 127.0.0.1
|
||||||
local_port = 80
|
local_port = 80
|
||||||
use_gzip = true
|
use_gzip = true
|
||||||
custom_domains = web03.yourdomain.com
|
custom_domains = web03.yourdomain.com
|
||||||
|
host_header_rewrite = example.com
|
||||||
|
|
|
@ -138,14 +138,14 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
|
||||||
|
|
||||||
nowTime := time.Now().Unix()
|
nowTime := time.Now().Unix()
|
||||||
req := &msg.ControlReq{
|
req := &msg.ControlReq{
|
||||||
Type: consts.NewCtlConn,
|
Type: consts.NewCtlConn,
|
||||||
ProxyName: cli.Name,
|
ProxyName: cli.Name,
|
||||||
UseEncryption: cli.UseEncryption,
|
UseEncryption: cli.UseEncryption,
|
||||||
UseGzip: cli.UseGzip,
|
UseGzip: cli.UseGzip,
|
||||||
PoolCount: cli.PoolCount,
|
PrivilegeMode: cli.PrivilegeMode,
|
||||||
PrivilegeMode: cli.PrivilegeMode,
|
ProxyType: cli.Type,
|
||||||
ProxyType: cli.Type,
|
HostHeaderRewrite: cli.HostHeaderRewrite,
|
||||||
Timestamp: nowTime,
|
Timestamp: nowTime,
|
||||||
}
|
}
|
||||||
if cli.PrivilegeMode {
|
if cli.PrivilegeMode {
|
||||||
privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime))
|
privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime))
|
||||||
|
|
|
@ -276,6 +276,7 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
|
||||||
// set infomations from frpc
|
// set infomations from frpc
|
||||||
s.UseEncryption = req.UseEncryption
|
s.UseEncryption = req.UseEncryption
|
||||||
s.UseGzip = req.UseGzip
|
s.UseGzip = req.UseGzip
|
||||||
|
s.HostHeaderRewrite = req.HostHeaderRewrite
|
||||||
if req.PoolCount > server.MaxPoolCount {
|
if req.PoolCount > server.MaxPoolCount {
|
||||||
s.PoolCount = server.MaxPoolCount
|
s.PoolCount = server.MaxPoolCount
|
||||||
} else if req.PoolCount < 0 {
|
} else if req.PoolCount < 0 {
|
||||||
|
|
|
@ -140,6 +140,14 @@ func LoadConf(confFile string) (err error) {
|
||||||
proxyClient.UseGzip = true
|
proxyClient.UseGzip = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if proxyClient.Type == "http" {
|
||||||
|
// host_header_rewrite
|
||||||
|
tmpStr, ok = section["host_header_rewrite"]
|
||||||
|
if ok {
|
||||||
|
proxyClient.HostHeaderRewrite = tmpStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// privilege_mode
|
// privilege_mode
|
||||||
proxyClient.PrivilegeMode = false
|
proxyClient.PrivilegeMode = false
|
||||||
tmpStr, ok = section["privilege_mode"]
|
tmpStr, ok = section["privilege_mode"]
|
||||||
|
@ -178,6 +186,7 @@ func LoadConf(confFile string) (err error) {
|
||||||
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name)
|
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name)
|
||||||
}
|
}
|
||||||
} else if proxyClient.Type == "http" {
|
} else if proxyClient.Type == "http" {
|
||||||
|
// custom_domains
|
||||||
domainStr, ok := section["custom_domains"]
|
domainStr, ok := section["custom_domains"]
|
||||||
if ok {
|
if ok {
|
||||||
proxyClient.CustomDomains = strings.Split(domainStr, ",")
|
proxyClient.CustomDomains = strings.Split(domainStr, ",")
|
||||||
|
@ -191,6 +200,7 @@ func LoadConf(confFile string) (err error) {
|
||||||
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name)
|
return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name)
|
||||||
}
|
}
|
||||||
} else if proxyClient.Type == "https" {
|
} else if proxyClient.Type == "https" {
|
||||||
|
// custom_domains
|
||||||
domainStr, ok := section["custom_domains"]
|
domainStr, ok := section["custom_domains"]
|
||||||
if ok {
|
if ok {
|
||||||
proxyClient.CustomDomains = strings.Split(domainStr, ",")
|
proxyClient.CustomDomains = strings.Split(domainStr, ",")
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
type BaseConf struct {
|
type BaseConf struct {
|
||||||
Name string
|
Name string
|
||||||
AuthToken string
|
AuthToken string
|
||||||
Type string
|
Type string
|
||||||
UseEncryption bool
|
UseEncryption bool
|
||||||
UseGzip bool
|
UseGzip bool
|
||||||
PrivilegeMode bool
|
PrivilegeMode bool
|
||||||
PrivilegeToken string
|
PrivilegeToken string
|
||||||
PoolCount int64
|
PoolCount int64
|
||||||
|
HostHeaderRewrite string
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,12 +29,13 @@ type ControlReq struct {
|
||||||
PoolCount int64 `json:"pool_count"`
|
PoolCount int64 `json:"pool_count"`
|
||||||
|
|
||||||
// configures used if privilege_mode is enabled
|
// configures used if privilege_mode is enabled
|
||||||
PrivilegeMode bool `json:"privilege_mode"`
|
PrivilegeMode bool `json:"privilege_mode"`
|
||||||
PrivilegeKey string `json:"privilege_key"`
|
PrivilegeKey string `json:"privilege_key"`
|
||||||
ProxyType string `json:"proxy_type"`
|
ProxyType string `json:"proxy_type"`
|
||||||
RemotePort int64 `json:"remote_port"`
|
RemotePort int64 `json:"remote_port"`
|
||||||
CustomDomains []string `json:"custom_domains, omitempty"`
|
CustomDomains []string `json:"custom_domains, omitempty"`
|
||||||
Timestamp int64 `json:"timestamp"`
|
HostHeaderRewrite string `json:"host_header_rewrite"`
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ControlRes struct {
|
type ControlRes struct {
|
||||||
|
|
|
@ -15,12 +15,10 @@
|
||||||
package msg
|
package msg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"frp/models/config"
|
"frp/models/config"
|
||||||
|
@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
|
||||||
defer wait.Done()
|
defer wait.Done()
|
||||||
|
|
||||||
// we don't care about errors here
|
// we don't care about errors here
|
||||||
pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord)
|
pipeEncrypt(from, to, conf, needRecord)
|
||||||
}
|
}
|
||||||
|
|
||||||
decryptPipe := func(to *conn.Conn, from *conn.Conn) {
|
decryptPipe := func(to *conn.Conn, from *conn.Conn) {
|
||||||
|
@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
|
||||||
defer wait.Done()
|
defer wait.Done()
|
||||||
|
|
||||||
// we don't care about errors here
|
// we don't care about errors here
|
||||||
pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord)
|
pipeDecrypt(to, from, conf, needRecord)
|
||||||
}
|
}
|
||||||
|
|
||||||
wait.Add(2)
|
wait.Add(2)
|
||||||
|
@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// decrypt msg from reader, then write into writer
|
// decrypt msg from reader, then write into writer
|
||||||
func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
|
func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
|
||||||
laes := new(pcrypto.Pcrypto)
|
laes := new(pcrypto.Pcrypto)
|
||||||
key := conf.AuthToken
|
key := conf.AuthToken
|
||||||
if conf.PrivilegeMode {
|
if conf.PrivilegeMode {
|
||||||
|
@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
|
||||||
|
|
||||||
buf := make([]byte, 5*1024+4)
|
buf := make([]byte, 5*1024+4)
|
||||||
var left, res []byte
|
var left, res []byte
|
||||||
var cnt int
|
var cnt int = -1
|
||||||
|
|
||||||
// record
|
// record
|
||||||
var flowBytes int64 = 0
|
var flowBytes int64 = 0
|
||||||
|
@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
nreader := bufio.NewReader(r)
|
|
||||||
for {
|
for {
|
||||||
// there may be more than 1 package in variable
|
// there may be more than 1 package in variable
|
||||||
// and we read more bytes if unpkgMsg returns an error
|
// and we read more bytes if unpkgMsg returns an error
|
||||||
var newBuf []byte
|
var newBuf []byte
|
||||||
if cnt < 0 {
|
if cnt < 0 {
|
||||||
n, err := nreader.Read(buf)
|
n, err := r.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = w.Write(res)
|
_, err = w.WriteBytes(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// recvive msg from reader, then encrypt msg into writer
|
// recvive msg from reader, then encrypt msg into writer
|
||||||
func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
|
func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
|
||||||
laes := new(pcrypto.Pcrypto)
|
laes := new(pcrypto.Pcrypto)
|
||||||
key := conf.AuthToken
|
key := conf.AuthToken
|
||||||
if conf.PrivilegeMode {
|
if conf.PrivilegeMode {
|
||||||
|
@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
nreader := bufio.NewReader(r)
|
|
||||||
buf := make([]byte, 5*1024)
|
buf := make([]byte, 5*1024)
|
||||||
for {
|
for {
|
||||||
n, err := nreader.Read(buf)
|
n, err := r.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
res = pkgMsg(res)
|
res = pkgMsg(res)
|
||||||
_, err = w.Write(res)
|
_, err = w.WriteBytes(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,6 +65,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) {
|
||||||
p.BindAddr = BindAddr
|
p.BindAddr = BindAddr
|
||||||
p.ListenPort = req.RemotePort
|
p.ListenPort = req.RemotePort
|
||||||
p.CustomDomains = req.CustomDomains
|
p.CustomDomains = req.CustomDomains
|
||||||
|
p.HostHeaderRewrite = req.HostHeaderRewrite
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,7 +82,7 @@ func (p *ProxyServer) Init() {
|
||||||
|
|
||||||
func (p *ProxyServer) Compare(p2 *ProxyServer) bool {
|
func (p *ProxyServer) Compare(p2 *ProxyServer) bool {
|
||||||
if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type ||
|
if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type ||
|
||||||
p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort {
|
p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(p.CustomDomains) != len(p2.CustomDomains) {
|
if len(p.CustomDomains) != len(p2.CustomDomains) {
|
||||||
|
@ -115,7 +116,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
|
||||||
p.listeners = append(p.listeners, l)
|
p.listeners = append(p.listeners, l)
|
||||||
} else if p.Type == "http" {
|
} else if p.Type == "http" {
|
||||||
for _, domain := range p.CustomDomains {
|
for _, domain := range p.CustomDomains {
|
||||||
l, err := VhostHttpMuxer.Listen(domain)
|
l, err := VhostHttpMuxer.Listen(domain, p.HostHeaderRewrite)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -123,7 +124,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
|
||||||
}
|
}
|
||||||
} else if p.Type == "https" {
|
} else if p.Type == "https" {
|
||||||
for _, domain := range p.CustomDomains {
|
for _, domain := range p.CustomDomains {
|
||||||
l, err := VhostHttpsMuxer.Listen(domain)
|
l, err := VhostHttpsMuxer.Listen(domain, p.HostHeaderRewrite)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -160,14 +161,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// start another goroutine for join two connections between frpc and user
|
go func(userConn *conn.Conn) {
|
||||||
go func() {
|
|
||||||
workConn, err := p.getWorkConn()
|
workConn, err := p.getWorkConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userConn := c
|
|
||||||
// message will be transferred to another without modifying
|
// message will be transferred to another without modifying
|
||||||
// l means local, r means remote
|
// l means local, r means remote
|
||||||
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
|
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
|
||||||
|
@ -176,7 +175,8 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
|
||||||
metric.OpenConnection(p.Name)
|
metric.OpenConnection(p.Name)
|
||||||
needRecord := true
|
needRecord := true
|
||||||
go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
|
go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
|
||||||
}()
|
metric.OpenConnection(p.Name)
|
||||||
|
}(c)
|
||||||
}
|
}
|
||||||
}(listener)
|
}(listener)
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,6 +117,16 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if the tcpConn is different with c.TcpConn
|
||||||
|
// you should call c.Close() first
|
||||||
|
func (c *Conn) SetTcpConn(tcpConn net.Conn) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
c.TcpConn = tcpConn
|
||||||
|
c.closeFlag = false
|
||||||
|
c.Reader = bufio.NewReader(c.TcpConn)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) GetRemoteAddr() (addr string) {
|
func (c *Conn) GetRemoteAddr() (addr string) {
|
||||||
return c.TcpConn.RemoteAddr().String()
|
return c.TcpConn.RemoteAddr().String()
|
||||||
}
|
}
|
||||||
|
@ -125,6 +135,11 @@ func (c *Conn) GetLocalAddr() (addr string) {
|
||||||
return c.TcpConn.LocalAddr().String()
|
return c.TcpConn.LocalAddr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = c.Reader.Read(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) ReadLine() (buff string, err error) {
|
func (c *Conn) ReadLine() (buff string, err error) {
|
||||||
buff, err = c.Reader.ReadString('\n')
|
buff, err = c.Reader.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -138,10 +153,14 @@ func (c *Conn) ReadLine() (buff string, err error) {
|
||||||
return buff, err
|
return buff, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WriteBytes(content []byte) (n int, err error) {
|
||||||
|
n, err = c.TcpConn.Write(content)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(content string) (err error) {
|
func (c *Conn) Write(content string) (err error) {
|
||||||
_, err = c.TcpConn.Write([]byte(content))
|
_, err = c.TcpConn.Write([]byte(content))
|
||||||
return err
|
return err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetDeadline(t time.Time) error {
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
|
|
@ -16,8 +16,12 @@ package vhost
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -42,6 +46,123 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
|
func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
|
||||||
mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
|
mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout)
|
||||||
return &HttpMuxer{mux}, err
|
return &HttpMuxer{mux}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) {
|
||||||
|
sc, rd := newShareConn(c.TcpConn)
|
||||||
|
var buff []byte
|
||||||
|
if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
|
||||||
|
return sc, err
|
||||||
|
}
|
||||||
|
err = sc.WriteBuff(buff)
|
||||||
|
return sc, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
|
||||||
|
buffer := make([]byte, 1024)
|
||||||
|
request.Read(buffer)
|
||||||
|
retBuffer, err := parseRequest(buffer, rewriteHost)
|
||||||
|
return retBuffer, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
|
||||||
|
tp := bytes.NewBuffer(org)
|
||||||
|
// First line: GET /index.html HTTP/1.0
|
||||||
|
var b []byte
|
||||||
|
if b, err = tp.ReadBytes('\n'); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req := new(http.Request)
|
||||||
|
// we invoked ReadRequest in GetHttpHostname before, so we ignore error
|
||||||
|
req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
|
||||||
|
rawurl := req.RequestURI
|
||||||
|
// CONNECT www.google.com:443 HTTP/1.1
|
||||||
|
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
|
||||||
|
if justAuthority {
|
||||||
|
rawurl = "http://" + rawurl
|
||||||
|
}
|
||||||
|
req.URL, _ = url.ParseRequestURI(rawurl)
|
||||||
|
if justAuthority {
|
||||||
|
// Strip the bogus "http://" back off.
|
||||||
|
req.URL.Scheme = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// RFC2616: first case
|
||||||
|
// GET /index.html HTTP/1.1
|
||||||
|
// Host: www.google.com
|
||||||
|
if req.URL.Host == "" {
|
||||||
|
changedBuf, err := changeHostName(tp, rewriteHost)
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
buf.Write(b)
|
||||||
|
buf.Write(changedBuf)
|
||||||
|
return buf.Bytes(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RFC2616: second case
|
||||||
|
// GET http://www.google.com/index.html HTTP/1.1
|
||||||
|
// Host: doesntmatter
|
||||||
|
// In this case, any Host line is ignored.
|
||||||
|
hostPort := strings.Split(req.URL.Host, ":")
|
||||||
|
if len(hostPort) == 1 {
|
||||||
|
req.URL.Host = rewriteHost
|
||||||
|
} else if len(hostPort) == 2 {
|
||||||
|
req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
|
||||||
|
}
|
||||||
|
firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
buf.WriteString(firstLine)
|
||||||
|
tp.WriteTo(buf)
|
||||||
|
return buf.Bytes(), err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
||||||
|
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
||||||
|
s1 := strings.Index(line, " ")
|
||||||
|
s2 := strings.Index(line[s1+1:], " ")
|
||||||
|
if s1 < 0 || s2 < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s2 += s1 + 1
|
||||||
|
return line[:s1], line[s1+1 : s2], line[s2+1:], true
|
||||||
|
}
|
||||||
|
|
||||||
|
func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
|
||||||
|
retBuf := new(bytes.Buffer)
|
||||||
|
|
||||||
|
peek := buff.Bytes()
|
||||||
|
for len(peek) > 0 {
|
||||||
|
i := bytes.IndexByte(peek, '\n')
|
||||||
|
if i < 3 {
|
||||||
|
// Not present (-1) or found within the next few bytes,
|
||||||
|
// implying we're at the end ("\r\n\r\n" or "\n\n")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
kv := peek[:i]
|
||||||
|
j := bytes.IndexByte(kv, ':')
|
||||||
|
if j < 0 {
|
||||||
|
return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
|
||||||
|
var hostHeader string
|
||||||
|
portPos := bytes.IndexByte(kv[j+1:], ':')
|
||||||
|
if portPos == -1 {
|
||||||
|
hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
|
||||||
|
} else {
|
||||||
|
hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
|
||||||
|
}
|
||||||
|
retBuf.WriteString(hostHeader)
|
||||||
|
peek = peek[i+1:]
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
retBuf.Write(peek[:i])
|
||||||
|
retBuf.WriteByte('\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
peek = peek[i+1:]
|
||||||
|
}
|
||||||
|
retBuf.Write(peek)
|
||||||
|
return retBuf.Bytes(), err
|
||||||
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ type HttpsMuxer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
|
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
|
||||||
mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
|
mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout)
|
||||||
return &HttpsMuxer{mux}, err
|
return &HttpsMuxer{mux}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,37 +27,42 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type muxFunc func(*conn.Conn) (net.Conn, string, error)
|
type muxFunc func(*conn.Conn) (net.Conn, string, error)
|
||||||
|
type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
|
||||||
|
|
||||||
type VhostMuxer struct {
|
type VhostMuxer struct {
|
||||||
listener *conn.Listener
|
listener *conn.Listener
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
vhostFunc muxFunc
|
vhostFunc muxFunc
|
||||||
|
rewriteFunc hostRewriteFunc
|
||||||
registryMap map[string]*Listener
|
registryMap map[string]*Listener
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
|
func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
|
||||||
mux = &VhostMuxer{
|
mux = &VhostMuxer{
|
||||||
listener: listener,
|
listener: listener,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
vhostFunc: vhostFunc,
|
vhostFunc: vhostFunc,
|
||||||
|
rewriteFunc: rewriteFunc,
|
||||||
registryMap: make(map[string]*Listener),
|
registryMap: make(map[string]*Listener),
|
||||||
}
|
}
|
||||||
go mux.run()
|
go mux.run()
|
||||||
return mux, nil
|
return mux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *VhostMuxer) Listen(name string) (l *Listener, err error) {
|
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
|
||||||
|
func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
|
||||||
v.mutex.Lock()
|
v.mutex.Lock()
|
||||||
defer v.mutex.Unlock()
|
defer v.mutex.Unlock()
|
||||||
if _, exist := v.registryMap[name]; exist {
|
if _, exist := v.registryMap[name]; exist {
|
||||||
return nil, fmt.Errorf("name %s is already bound", name)
|
return nil, fmt.Errorf("domain name %s is already bound", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
l = &Listener{
|
l = &Listener{
|
||||||
name: name,
|
name: name,
|
||||||
mux: v,
|
rewriteHost: rewriteHost,
|
||||||
accept: make(chan *conn.Conn),
|
mux: v,
|
||||||
|
accept: make(chan *conn.Conn),
|
||||||
}
|
}
|
||||||
v.registryMap[name] = l
|
v.registryMap[name] = l
|
||||||
return l, nil
|
return l, nil
|
||||||
|
@ -105,15 +110,16 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
|
||||||
if err = sConn.SetDeadline(time.Time{}); err != nil {
|
if err = sConn.SetDeadline(time.Time{}); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.TcpConn = sConn
|
c.SetTcpConn(sConn)
|
||||||
|
|
||||||
l.accept <- c
|
l.accept <- c
|
||||||
}
|
}
|
||||||
|
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
name string
|
name string
|
||||||
mux *VhostMuxer // for closing VhostMuxer
|
rewriteHost string
|
||||||
accept chan *conn.Conn
|
mux *VhostMuxer // for closing VhostMuxer
|
||||||
|
accept chan *conn.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listener) Accept() (*conn.Conn, error) {
|
func (l *Listener) Accept() (*conn.Conn, error) {
|
||||||
|
@ -121,6 +127,17 @@ func (l *Listener) Accept() (*conn.Conn, error) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("Listener closed")
|
return nil, fmt.Errorf("Listener closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if rewriteFunc is exist and rewriteHost is set
|
||||||
|
// rewrite http requests with a modified host header
|
||||||
|
if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
|
||||||
|
fmt.Printf("host rewrite: %s\n", l.rewriteHost)
|
||||||
|
sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("http host header rewrite failed")
|
||||||
|
}
|
||||||
|
conn.SetTcpConn(sConn)
|
||||||
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,6 +157,7 @@ type sharedConn struct {
|
||||||
buff *bytes.Buffer
|
buff *bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// the bytes you read in io.Reader, will be reserved in sharedConn
|
||||||
func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
|
func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
|
||||||
sc := &sharedConn{
|
sc := &sharedConn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
|
@ -166,3 +184,9 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) {
|
||||||
sc.Unlock()
|
sc.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sc *sharedConn) WriteBuff(buffer []byte) (err error) {
|
||||||
|
sc.buff.Reset()
|
||||||
|
_, err = sc.buff.Write(buffer)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue