sock5验证加密、udp隧道、gzip、snnapy压缩

pull/1219/head
刘河 2018-12-03 23:03:25 +08:00
parent 83eb8dcb3c
commit 212d74bbc4
7 changed files with 345 additions and 104 deletions

View File

@ -70,7 +70,7 @@ func (s *TRPClient) process(c *Conn) error {
return err return err
} }
case WORK_CHAN: //隧道模式每次开启10个加快连接速度 case WORK_CHAN: //隧道模式每次开启10个加快连接速度
for i := 0; i < 10; i++ { for i := 0; i < 100; i++ {
go s.dealChan() go s.dealChan()
} }
case RES_MSG: case RES_MSG:
@ -86,6 +86,9 @@ func (s *TRPClient) process(c *Conn) error {
func (s *TRPClient) dealChan() error { func (s *TRPClient) dealChan() error {
//创建一个tcp连接 //创建一个tcp连接
conn, err := net.Dial("tcp", s.svrAddr) conn, err := net.Dial("tcp", s.svrAddr)
if err != nil {
return err
}
//验证 //验证
if _, err := conn.Write(getverifyval()); err != nil { if _, err := conn.Write(getverifyval()); err != nil {
return err return err
@ -95,36 +98,31 @@ func (s *TRPClient) dealChan() error {
c.SetAlive() c.SetAlive()
//写标志 //写标志
c.wChan() c.wChan()
//获取连接的host //获取连接的host type(tcp or udp)
host, err := c.GetHostFromConn() typeStr, host, err := c.GetHostFromConn()
if err != nil { if err != nil {
return err return err
} }
//与目标建立连接 //与目标建立连接
server, err := net.Dial("tcp", host) server, err := net.Dial(typeStr, host)
if err != nil { if err != nil {
fmt.Println(err) log.Println(err)
return err return err
} }
//创建成功后io.copy go relay(NewConn(server), c, DataDecode)
go relay(server, c.conn) relay(c, NewConn(server), DataEncode)
relay(c.conn, server)
return nil return nil
} }
//http模式处理 //http模式处理
func (s *TRPClient) dealHttp(c *Conn) error { func (s *TRPClient) dealHttp(c *Conn) error {
nlen, err := c.GetLen() buf := make([]byte, 1024*32)
n, err := c.ReadFromCompress(buf, DataDecode)
if err != nil { if err != nil {
c.wError() c.wError()
return err return err
} }
raw, err := c.ReadLen(int(nlen)) req, err := DecodeRequest(buf[:n])
if err != nil {
c.wError()
return err
}
req, err := DecodeRequest(raw)
if err != nil { if err != nil {
c.wError() c.wError()
return err return err
@ -134,7 +132,8 @@ func (s *TRPClient) dealHttp(c *Conn) error {
c.wError() c.wError()
return err return err
} }
n, err := c.Write(respBytes) c.wSign()
n, err = c.WriteCompress(respBytes, DataEncode)
if err != nil { if err != nil {
return err return err
} }

77
conn.go
View File

@ -2,10 +2,13 @@ package main
import ( import (
"bytes" "bytes"
"compress/gzip"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/golang/snappy"
"io" "io"
"log"
"net" "net"
"net/url" "net/url"
"regexp" "regexp"
@ -47,7 +50,7 @@ func (s *Conn) ReadLen(len int) ([]byte, error) {
//获取长度 //获取长度
func (s *Conn) GetLen() (int, error) { func (s *Conn) GetLen() (int, error) {
val := make([]byte, 4) val := make([]byte, 4)
_, err := s.conn.Read(val) _, err := s.Read(val)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -58,6 +61,21 @@ func (s *Conn) GetLen() (int, error) {
return int(nlen), nil return int(nlen), nil
} }
//写入长度
func (s *Conn) WriteLen(buf []byte) (int, error) {
raw := bytes.NewBuffer([]byte{})
if err := binary.Write(raw, binary.LittleEndian, int32(len(buf))); err != nil {
log.Println(err)
return 0, err
}
if err = binary.Write(raw, binary.LittleEndian, buf); err != nil {
log.Println(err)
return 0, err
}
return s.Write(raw.Bytes())
}
//读取flag //读取flag
func (s *Conn) ReadFlag() (string, error) { func (s *Conn) ReadFlag() (string, error) {
val := make([]byte, 4) val := make([]byte, 4)
@ -69,22 +87,30 @@ func (s *Conn) ReadFlag() (string, error) {
} }
//读取host //读取host
func (s *Conn) GetHostFromConn() (string, error) { func (s *Conn) GetHostFromConn() (typeStr string, host string, err error) {
ltype := make([]byte, 3)
_, err = s.Read(ltype)
if err != nil {
return
}
typeStr = string(ltype)
len, err := s.GetLen() len, err := s.GetLen()
if err != nil { if err != nil {
return "", err return
} }
hostByte := make([]byte, len) hostByte := make([]byte, len)
_, err = s.conn.Read(hostByte) _, err = s.conn.Read(hostByte)
if err != nil { if err != nil {
return "", err return
} }
return string(hostByte), nil host = string(hostByte)
return
} }
//获取host //写tcp host
func (s *Conn) WriteHost(host string) (int, error) { func (s *Conn) WriteHost(ltype string, host string) (int, error) {
raw := bytes.NewBuffer([]byte{}) raw := bytes.NewBuffer([]byte{})
binary.Write(raw, binary.LittleEndian, []byte(ltype))
binary.Write(raw, binary.LittleEndian, int32(len([]byte(host)))) binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
binary.Write(raw, binary.LittleEndian, []byte(host)) binary.Write(raw, binary.LittleEndian, []byte(host))
return s.Write(raw.Bytes()) return s.Write(raw.Bytes())
@ -139,10 +165,47 @@ func (s *Conn) Write(b []byte) (int, error) {
func (s *Conn) Read(b []byte) (int, error) { func (s *Conn) Read(b []byte) (int, error) {
return s.conn.Read(b) return s.conn.Read(b)
} }
func (s *Conn) ReadFromCompress(b []byte, compress int) (int, error) {
switch compress {
case COMPRESS_GZIP_DECODE:
r, err := gzip.NewReader(s)
if err != nil {
return 0, err
}
return r.Read(b)
case COMPRESS_SNAPY_DECODE:
r := snappy.NewReader(s)
return r.Read(b)
case COMPRESS_NONE:
return s.Read(b)
}
return 0, nil
}
func (s *Conn) WriteCompress(b []byte, compress int) (n int, err error) {
switch compress {
case COMPRESS_GZIP_ENCODE:
w := gzip.NewWriter(s)
if n, err = w.Write(b); err == nil {
w.Flush()
}
case COMPRESS_SNAPY_ENCODE:
w := snappy.NewBufferedWriter(s)
if n, err = w.Write(b); err == nil {
w.Flush()
}
case COMPRESS_NONE:
n, err = s.Write(b)
}
return
}
func (s *Conn) wError() { func (s *Conn) wError() {
s.conn.Write([]byte(RES_MSG)) s.conn.Write([]byte(RES_MSG))
} }
func (s *Conn) wSign() {
s.conn.Write([]byte(RES_SIGN))
}
func (s *Conn) wMain() { func (s *Conn) wMain() {
s.conn.Write([]byte(WORK_MAIN)) s.conn.Write([]byte(WORK_MAIN))

25
main.go
View File

@ -11,15 +11,33 @@ var (
tcpPort = flag.Int("tcpport", 8284, "Socket连接或者监听的端口") tcpPort = flag.Int("tcpport", 8284, "Socket连接或者监听的端口")
httpPort = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口当为mode为client时为转发至本地客户端的端口") httpPort = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口当为mode为client时为转发至本地客户端的端口")
rpMode = flag.String("mode", "client", "启动模式可选为client、server") rpMode = flag.String("mode", "client", "启动模式可选为client、server")
tunnelTarget = flag.String("target", "10.1.50.203:80", "tunnel模式远程目标") tunnelTarget = flag.String("target", "10.1.50.203:80", "远程目标")
verifyKey = flag.String("vkey", "", "验证密钥") verifyKey = flag.String("vkey", "", "验证密钥")
u = flag.String("u", "", "sock5验证用户名")
p = flag.String("p", "", "sock5验证密码")
compress = flag.String("compress", "", "数据压缩gizp|snappy")
config Config config Config
err error err error
DataEncode int
DataDecode int
) )
func main() { func main() {
flag.Parse() flag.Parse()
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
switch *compress {
case "":
DataDecode = COMPRESS_NONE
DataEncode = COMPRESS_NONE
case "gzip":
DataDecode = COMPRESS_GZIP_DECODE
DataEncode = COMPRESS_GZIP_ENCODE
case "snnapy":
DataDecode = COMPRESS_SNAPY_DECODE
DataEncode = COMPRESS_SNAPY_ENCODE
default:
log.Fatalln("数据压缩格式错误")
}
if *rpMode == "client" { if *rpMode == "client" {
JsonParse := NewJsonStruct() JsonParse := NewJsonStruct()
config, err = JsonParse.Load(*configPath) config, err = JsonParse.Load(*configPath)
@ -50,11 +68,14 @@ func main() {
svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessTunnel) svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessTunnel)
svr.Start() svr.Start()
} else if *rpMode == "sock5Server" { } else if *rpMode == "sock5Server" {
svr := NewSock5ModeServer(*tcpPort, *httpPort) svr := NewSock5ModeServer(*tcpPort, *httpPort, *u, *p)
svr.Start() svr.Start()
} else if *rpMode == "httpProxyServer" { } else if *rpMode == "httpProxyServer" {
svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessHttp) svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessHttp)
svr.Start() svr.Start()
} else if *rpMode == "udpServer" {
svr := NewUdpModeServer(*tcpPort, *httpPort, *tunnelTarget)
svr.Start()
} }
} }
} }

View File

@ -74,7 +74,8 @@ func (s *HttpModeServer) writeRequest(r *http.Request, conn *Conn) error {
if err != nil { if err != nil {
return err return err
} }
c, err := conn.Write(raw) conn.wSign()
c, err := conn.WriteCompress(raw, DataEncode)
if err != nil { if err != nil {
return err return err
} }
@ -92,15 +93,12 @@ func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error {
} }
switch flags { switch flags {
case RES_SIGN: case RES_SIGN:
nlen, err := c.GetLen() buf := make([]byte, 1024*32)
n, err := c.ReadFromCompress(buf, DataDecode)
if err != nil { if err != nil {
return err return err
} }
raw, err := c.ReadLen(nlen) resp, err := DecodeResponse(buf[:n])
if err != nil {
return err
}
resp, err := DecodeResponse(raw)
if err != nil { if err != nil {
return err return err
} }
@ -176,12 +174,12 @@ func (s *TunnelModeServer) startTunnelServer() {
func ProcessTunnel(c *Conn, s *TunnelModeServer) error { func ProcessTunnel(c *Conn, s *TunnelModeServer) error {
retry: retry:
link := s.GetTunnel() link := s.GetTunnel()
if _, err := link.WriteHost(s.tunnelTarget); err != nil { if _, err := link.WriteHost("tcp", s.tunnelTarget); err != nil {
link.Close() link.Close()
goto retry goto retry
} }
go relay(link.conn, c.conn) go relay(link, c, DataEncode)
relay(c.conn, link.conn) relay(c, link, DataDecode)
return nil return nil
} }
@ -194,16 +192,16 @@ func ProcessHttp(c *Conn, s *TunnelModeServer) error {
} }
retry: retry:
link := s.GetTunnel() link := s.GetTunnel()
if _, err := link.WriteHost(addr); err != nil { if _, err := link.WriteHost("tcp", addr); err != nil {
link.Close() link.Close()
goto retry goto retry
} }
if method == "CONNECT" { if method == "CONNECT" {
fmt.Fprint(c, "HTTP/1.1 200 Connection established\r\n") fmt.Fprint(c, "HTTP/1.1 200 Connection established\r\n")
} else { } else {
link.Write(rb) link.WriteCompress(rb, DataEncode)
} }
go relay(link.conn, c.conn) go relay(link, c, DataEncode)
relay(c.conn, link.conn) relay(c, link, DataDecode)
return nil return nil
} }

109
sock5.go
View File

@ -10,9 +10,9 @@ import (
) )
const ( const (
ipV4 = 1 ipV4 = 1
domainName = 3 domainName = 3
ipV6 = 4 ipV6 = 4
connectMethod = 1 connectMethod = 1
bindMethod = 2 bindMethod = 2
associateMethod = 3 associateMethod = 3
@ -35,9 +35,19 @@ const (
addrTypeNotSupported addrTypeNotSupported
) )
const (
UserPassAuth = uint8(2)
userAuthVersion = uint8(1)
authSuccess = uint8(0)
authFailure = uint8(1)
)
type Sock5ModeServer struct { type Sock5ModeServer struct {
Tunnel Tunnel
httpPort int httpPort int
u string //用户名
p string //密码
isVerify bool
} }
func (s *Sock5ModeServer) handleRequest(c net.Conn) { func (s *Sock5ModeServer) handleRequest(c net.Conn) {
@ -119,37 +129,31 @@ func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) (proxyConn *Conn,
var port uint16 var port uint16
binary.Read(c, binary.BigEndian, &port) binary.Read(c, binary.BigEndian, &port)
// connect to host // connect to host
addr := net.JoinHostPort(host, strconv.Itoa(int(port))) addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
//取出一个连接 client := s.GetTunnel()
if len(s.tunnelList) < 10 { //新建通道
go s.newChan()
}
client := <-s.tunnelList
s.sendReply(c, succeeded) s.sendReply(c, succeeded)
_, err = client.WriteHost(addr) var ltype string
if command == associateMethod {
ltype = "udp"
} else {
ltype = "tcp"
}
_, err = client.WriteHost(ltype, addr)
return client, nil return client, nil
} }
func (s *Sock5ModeServer) handleConnect(c net.Conn) { func (s *Sock5ModeServer) handleConnect(c net.Conn) {
proxyConn, err := s.doConnect(c, connectMethod) proxyConn, err := s.doConnect(c, connectMethod)
if err != nil { if err != nil {
log.Println(err)
c.Close() c.Close()
} else { } else {
go io.Copy(c, proxyConn) go relay(proxyConn, NewConn(c), DataEncode)
go io.Copy(proxyConn, c) go relay(NewConn(c), proxyConn, DataDecode)
} }
} }
func (s *Sock5ModeServer) relay(in, out net.Conn) {
if _, err := io.Copy(in, out); err != nil {
log.Println("copy error", err)
}
in.Close() // will trigger an error in the other relay, then call out.Close()
}
// passive mode // passive mode
func (s *Sock5ModeServer) handleBind(c net.Conn) { func (s *Sock5ModeServer) handleBind(c net.Conn) {
} }
@ -177,8 +181,8 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) {
if err != nil { if err != nil {
c.Close() c.Close()
} else { } else {
go io.Copy(c, proxyConn) go relay(proxyConn, NewConn(c), DataEncode)
go io.Copy(proxyConn, c) go relay(NewConn(c), proxyConn, DataDecode)
} }
} }
@ -203,14 +207,56 @@ func (s *Sock5ModeServer) handleNewConn(c net.Conn) {
c.Close() c.Close()
return return
} }
// no authentication required for now if s.isVerify {
buf[1] = 0 buf[1] = UserPassAuth
// send a METHOD selection message c.Write(buf)
c.Write(buf) if err := s.Auth(c); err != nil {
c.Close()
log.Println("验证失败:", err)
return
}
} else {
buf[1] = 0
c.Write(buf)
}
s.handleRequest(c) s.handleRequest(c)
} }
func (s *Sock5ModeServer) Auth(c net.Conn) error {
header := []byte{0, 0}
if _, err := io.ReadAtLeast(c, header, 2); err != nil {
return err
}
if header[0] != userAuthVersion {
return errors.New("验证方式不被支持")
}
userLen := int(header[1])
user := make([]byte, userLen)
if _, err := io.ReadAtLeast(c, user, userLen); err != nil {
return err
}
if _, err := c.Read(header[:1]); err != nil {
return errors.New("密码长度获取错误")
}
passLen := int(header[0])
pass := make([]byte, passLen)
if _, err := io.ReadAtLeast(c, pass, passLen); err != nil {
return err
}
if string(pass) == s.p && string(user) == s.u {
if _, err := c.Write([]byte{userAuthVersion, authSuccess}); err != nil {
return err
}
return nil
} else {
if _, err := c.Write([]byte{userAuthVersion, authFailure}); err != nil {
return err
}
return errors.New("验证不通过")
}
return errors.New("未知错误")
}
func (s *Sock5ModeServer) Start() { func (s *Sock5ModeServer) Start() {
l, err := net.Listen("tcp", ":"+strconv.Itoa(s.httpPort)) l, err := net.Listen("tcp", ":"+strconv.Itoa(s.httpPort))
if err != nil { if err != nil {
@ -226,11 +272,18 @@ func (s *Sock5ModeServer) Start() {
} }
} }
func NewSock5ModeServer(tcpPort, httpPort int) *Sock5ModeServer { func NewSock5ModeServer(tcpPort, httpPort int, u, p string) *Sock5ModeServer {
s := new(Sock5ModeServer) s := new(Sock5ModeServer)
s.tunnelPort = tcpPort s.tunnelPort = tcpPort
s.httpPort = httpPort s.httpPort = httpPort
s.tunnelList = make(chan *Conn, 1000) s.tunnelList = make(chan *Conn, 1000)
s.signalList = make(chan *Conn, 10) s.signalList = make(chan *Conn, 10)
if u != "" && p != "" {
s.isVerify = true
s.u = u
s.p = p
} else {
s.isVerify = false
}
return s return s
} }

80
udp.go Executable file
View File

@ -0,0 +1,80 @@
package main
import (
"io"
"log"
"net"
"time"
)
type UdpModeServer struct {
Tunnel
udpPort int //监听的udp端口
tunnelTarget string //udp目标地址
listener *net.UDPConn
udpMap map[string]*Conn
}
func NewUdpModeServer(tcpPort, udpPort int, tunnelTarget string) *UdpModeServer {
s := new(UdpModeServer)
s.tunnelPort = tcpPort
s.udpPort = udpPort
s.tunnelTarget = tunnelTarget
s.tunnelList = make(chan *Conn, 1000)
s.signalList = make(chan *Conn, 10)
s.udpMap = make(map[string]*Conn)
return s
}
//开始
func (s *UdpModeServer) Start() (error) {
err := s.StartTunnel()
if err != nil {
log.Fatalln("启动失败!", err)
return err
}
s.startTunnelServer()
return nil
}
//udp监听
func (s *UdpModeServer) startTunnelServer() {
s.listener, err = net.ListenUDP("udp", &net.UDPAddr{net.ParseIP("0.0.0.0"), s.udpPort, ""})
if err != nil {
log.Fatalln(err)
}
data := make([]byte, 1472) //udp数据包大小
for {
n, addr, err := s.listener.ReadFromUDP(data)
if err != nil {
log.Println(err)
continue
}
go s.process(addr, data[:n])
}
}
func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) {
conn := s.GetTunnel()
conn.WriteHost("udp", s.tunnelTarget)
go func() {
for {
buf := make([]byte, 1024)
conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3)))
n, err := conn.ReadFromCompress(buf, DataDecode)
if err != nil || err == io.EOF {
conn.Close()
break
}
_, err = s.listener.WriteToUDP(buf[:n], addr)
if err != nil {
conn.Close()
break
}
}
}()
if _, err = conn.WriteCompress(data, DataEncode); err != nil {
conn.Close()
}
}

101
util.go
View File

@ -7,8 +7,9 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/golang/snappy"
"io" "io"
"net" "log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
@ -20,6 +21,14 @@ var (
disabledRedirect = errors.New("disabled redirect.") disabledRedirect = errors.New("disabled redirect.")
) )
const (
COMPRESS_NONE = iota
COMPRESS_SNAPY_ENCODE
COMPRESS_SNAPY_DECODE
COMPRESS_GZIP_ENCODE
COMPRESS_GZIP_DECODE
)
func BadRequest(w http.ResponseWriter) { func BadRequest(w http.ResponseWriter) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
} }
@ -46,30 +55,20 @@ func GetEncodeResponse(req *http.Request) ([]byte, error) {
return respBytes, nil return respBytes, nil
} }
// 将request 的处理 // 将request转为bytes
func EncodeRequest(r *http.Request) ([]byte, error) { func EncodeRequest(r *http.Request) ([]byte, error) {
raw := bytes.NewBuffer([]byte{}) raw := bytes.NewBuffer([]byte{})
// 写签名
binary.Write(raw, binary.LittleEndian, []byte("sign"))
reqBytes, err := httputil.DumpRequest(r, true) reqBytes, err := httputil.DumpRequest(r, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 写body数据长度 + 1
binary.Write(raw, binary.LittleEndian, int32(len(reqBytes)+1))
// 判断是否为http或者https的标识1字节
binary.Write(raw, binary.LittleEndian, bool(r.URL.Scheme == "https")) binary.Write(raw, binary.LittleEndian, bool(r.URL.Scheme == "https"))
if err := binary.Write(raw, binary.LittleEndian, reqBytes); err != nil { binary.Write(raw, binary.LittleEndian, reqBytes)
return nil, err
}
return raw.Bytes(), nil return raw.Bytes(), nil
} }
// 将字节转为request // 将字节转为request
func DecodeRequest(data []byte) (*http.Request, error) { func DecodeRequest(data []byte) (*http.Request, error) {
if len(data) <= 100 {
return nil, errors.New("待解码的字节长度太小")
}
req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:]))) req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:])))
if err != nil { if err != nil {
return nil, err return nil, err
@ -84,42 +83,25 @@ func DecodeRequest(data []byte) (*http.Request, error) {
scheme = "https" scheme = "https"
} }
req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI)) req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI))
fmt.Println(req.URL)
req.RequestURI = "" req.RequestURI = ""
return req, nil return req, nil
} }
//// 将response转为字节 //// 将response转为字节
func EncodeResponse(r *http.Response) ([]byte, error) { func EncodeResponse(r *http.Response) ([]byte, error) {
raw := bytes.NewBuffer([]byte{})
binary.Write(raw, binary.LittleEndian, []byte(RES_SIGN))
respBytes, err := httputil.DumpResponse(r, true) respBytes, err := httputil.DumpResponse(r, true)
if config.Replace == 1 {
respBytes = replaceHost(respBytes)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
var buf bytes.Buffer if config.Replace == 1 {
zw := gzip.NewWriter(&buf) respBytes = replaceHost(respBytes)
zw.Write(respBytes)
zw.Close()
binary.Write(raw, binary.LittleEndian, int32(len(buf.Bytes())))
if err := binary.Write(raw, binary.LittleEndian, buf.Bytes()); err != nil {
fmt.Println(err)
return nil, err
} }
return raw.Bytes(), nil return respBytes, nil
} }
// 将字节转为response // 将字节转为response
func DecodeResponse(data []byte) (*http.Response, error) { func DecodeResponse(data []byte) (*http.Response, error) {
zr, err := gzip.NewReader(bytes.NewReader(data)) resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), nil)
if err != nil {
return nil, err
}
defer zr.Close()
resp, err := http.ReadResponse(bufio.NewReader(zr), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -144,7 +126,52 @@ func replaceHost(resp []byte) []byte {
return []byte(str) return []byte(str)
} }
func relay(in, out net.Conn) { func relay(in, out *Conn, compressType int) {
io.Copy(in, out); buf := make([]byte, 32*1024)
in.Close() switch compressType {
case COMPRESS_GZIP_ENCODE:
w := gzip.NewWriter(in)
for {
n, err := out.Read(buf)
if err != nil || err == io.EOF {
break
}
if _, err = w.Write(buf[:n]); err != nil {
break
}
if err = w.Flush(); err != nil {
log.Println(err)
break
}
}
w.Close()
case COMPRESS_SNAPY_ENCODE:
w := snappy.NewBufferedWriter(in)
for {
n, err := out.Read(buf)
if err != nil || err == io.EOF {
break
}
if _, err = w.Write(buf[:n]); err != nil {
break
}
if err = w.Flush(); err != nil {
log.Println(err)
break
}
}
w.Close()
case COMPRESS_GZIP_DECODE:
r, err := gzip.NewReader(out)
if err != nil {
return
}
io.Copy(in, r)
case COMPRESS_SNAPY_DECODE:
r := snappy.NewReader(out)
io.Copy(in, r)
default:
io.Copy(in, out)
}
out.Close()
} }