nps/utils/conn.go

371 lines
7.5 KiB
Go
Raw Normal View History

2019-01-09 12:33:00 +00:00
package utils
2018-11-29 11:55:24 +00:00
import (
2018-12-11 08:37:12 +00:00
"bufio"
2018-11-29 11:55:24 +00:00
"bytes"
"encoding/binary"
"errors"
"github.com/golang/snappy"
2019-01-03 16:21:23 +00:00
"io"
"log"
2018-11-29 11:55:24 +00:00
"net"
2018-12-11 08:37:12 +00:00
"net/http"
2018-11-30 18:38:29 +00:00
"net/url"
2018-12-11 08:37:12 +00:00
"strconv"
2018-11-30 18:38:29 +00:00
"strings"
2018-11-29 11:55:24 +00:00
"time"
)
2019-01-09 12:33:00 +00:00
// cryptKey is the AES key passed to AesEncrypt/AesDecrypt by CryptConn and
// SnappyConn for every encrypted payload.
// NOTE(review): this is a fixed literal compiled into every binary, so it
// presumably offers obfuscation rather than confidentiality — consider making
// the key configurable per deployment.
const cryptKey = "1234567812345678"
2019-01-02 17:44:45 +00:00
// CryptConn sends each Write as one frame (a 4-byte little-endian length
// header followed by the payload) and reads one such frame per Read. When
// crypt is set, payloads are AES-encrypted with cryptKey before framing and
// decrypted after reading.
type CryptConn struct {
	conn  net.Conn // underlying transport
	crypt bool     // encrypt on Write / decrypt on Read
}

// NewCryptConn wraps conn; crypt selects whether payloads are encrypted.
func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
	return &CryptConn{conn: conn, crypt: crypt}
}
2019-01-03 16:21:23 +00:00
//加密写
2019-01-02 17:44:45 +00:00
func (s *CryptConn) Write(b []byte) (n int, err error) {
n = len(b)
if s.crypt {
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
return
}
2019-01-06 17:52:54 +00:00
}
if b, err = GetLenBytes(b); err != nil {
return
2019-01-02 17:44:45 +00:00
}
_, err = s.conn.Write(b)
return
}
2019-01-03 16:21:23 +00:00
//解密读
2019-01-02 17:44:45 +00:00
func (s *CryptConn) Read(b []byte) (n int, err error) {
2019-01-05 19:16:46 +00:00
defer func() {
2019-01-06 17:52:54 +00:00
if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
2019-01-05 19:16:46 +00:00
err = io.EOF
n = 0
}
}()
2019-01-06 17:52:54 +00:00
var lens int
2019-01-11 10:10:37 +00:00
var buf []byte
2019-01-12 16:09:12 +00:00
var rb []byte
2019-01-06 17:52:54 +00:00
c := NewConn(s.conn)
if lens, err = c.GetLen(); err != nil {
return
}
if buf, err = c.ReadLen(lens); err != nil {
return
}
2019-01-02 17:44:45 +00:00
if s.crypt {
2019-01-12 16:09:12 +00:00
if rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
2019-01-02 17:44:45 +00:00
return
}
2019-01-06 17:52:54 +00:00
} else {
2019-01-12 16:09:12 +00:00
rb = buf
2019-01-02 17:44:45 +00:00
}
2019-01-12 16:09:12 +00:00
copy(b, rb)
n = len(rb)
return
2019-01-02 17:44:45 +00:00
}
2018-12-11 08:37:12 +00:00
type SnappyConn struct {
2019-01-02 17:44:45 +00:00
w *snappy.Writer
r *snappy.Reader
crypt bool
2018-12-11 08:37:12 +00:00
}
2019-01-02 17:44:45 +00:00
func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
2018-12-11 08:37:12 +00:00
c := new(SnappyConn)
c.w = snappy.NewBufferedWriter(conn)
c.r = snappy.NewReader(conn)
2019-01-02 17:44:45 +00:00
c.crypt = crypt
2018-12-11 08:37:12 +00:00
return c
}
2019-01-03 16:21:23 +00:00
//snappy压缩写 包含加密
2018-12-11 08:37:12 +00:00
func (s *SnappyConn) Write(b []byte) (n int, err error) {
2019-01-02 17:44:45 +00:00
n = len(b)
if s.crypt {
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
log.Println("encode crypt error:", err)
return
}
}
if _, err = s.w.Write(b); err != nil {
2018-12-11 08:37:12 +00:00
return
}
err = s.w.Flush()
return
}
2019-01-03 16:21:23 +00:00
//snappy压缩读 包含解密
2018-12-11 08:37:12 +00:00
func (s *SnappyConn) Read(b []byte) (n int, err error) {
2019-01-05 19:16:46 +00:00
defer func() {
2019-01-06 17:52:54 +00:00
if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
2019-01-05 19:16:46 +00:00
err = io.EOF
n = 0
}
}()
2019-01-11 10:10:37 +00:00
buf := bufPool.Get().([]byte)
2019-01-13 12:49:45 +00:00
defer bufPool.Put(buf)
2019-01-11 10:10:37 +00:00
if n, err = s.r.Read(buf); err != nil {
2019-01-02 17:44:45 +00:00
return
}
2019-01-11 10:10:37 +00:00
var bs []byte
2019-01-02 17:44:45 +00:00
if s.crypt {
2019-01-11 10:10:37 +00:00
if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
2019-01-02 17:44:45 +00:00
log.Println("decode crypt error:", err)
return
}
2019-01-11 10:10:37 +00:00
} else {
bs = buf[:n]
2019-01-02 17:44:45 +00:00
}
2019-01-11 10:10:37 +00:00
n = len(bs)
copy(b, bs)
2019-01-02 17:44:45 +00:00
return
2018-12-11 08:37:12 +00:00
}
2018-11-29 11:55:24 +00:00
// Conn wraps a net.Conn with helpers for length-prefixed framing, connection
// headers, and the short control markers exchanged between peers.
type Conn struct {
	Conn net.Conn
}

// NewConn wraps conn in a *Conn.
func NewConn(conn net.Conn) *Conn {
	return &Conn{Conn: conn}
}
2019-01-03 16:21:23 +00:00
//读取指定长度内容
2019-01-06 17:52:54 +00:00
func (s *Conn) ReadLen(cLen int) ([]byte, error) {
2019-01-11 10:10:37 +00:00
if cLen > poolSize {
return nil, errors.New("长度错误" + strconv.Itoa(cLen))
2019-01-06 17:52:54 +00:00
}
2019-01-12 16:09:12 +00:00
var buf []byte
if cLen <= poolSizeSmall {
buf = bufPoolSmall.Get().([]byte)[:cLen]
2019-01-13 12:49:45 +00:00
defer bufPoolSmall.Put(buf)
2019-01-12 16:09:12 +00:00
} else {
buf = bufPool.Get().([]byte)[:cLen]
2019-01-13 12:49:45 +00:00
defer bufPool.Put(buf)
2019-01-12 16:09:12 +00:00
}
2019-01-06 17:52:54 +00:00
if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
2019-01-02 17:44:45 +00:00
return buf, errors.New("读取指定长度错误" + err.Error())
2018-11-29 11:55:24 +00:00
}
2019-01-02 17:44:45 +00:00
return buf, nil
2018-11-29 11:55:24 +00:00
}
//获取长度
func (s *Conn) GetLen() (int, error) {
2019-01-03 16:21:23 +00:00
val, err := s.ReadLen(4)
if err != nil {
2018-11-29 11:55:24 +00:00
return 0, err
}
2019-01-02 17:44:45 +00:00
return GetLenByBytes(val)
2018-11-29 11:55:24 +00:00
}
2019-01-03 16:21:23 +00:00
//写入长度+内容 粘包
func (s *Conn) WriteLen(buf []byte) (int, error) {
2019-01-02 17:44:45 +00:00
var b []byte
2019-01-09 12:33:00 +00:00
var err error
2019-01-02 17:44:45 +00:00
if b, err = GetLenBytes(buf); err != nil {
return 0, err
}
2019-01-02 17:44:45 +00:00
return s.Write(b)
}
2018-11-29 11:55:24 +00:00
//读取flag
func (s *Conn) ReadFlag() (string, error) {
2019-01-03 16:21:23 +00:00
val, err := s.ReadLen(4)
if err != nil {
2018-11-29 11:55:24 +00:00
return "", err
}
return string(val), err
}
2018-12-11 08:37:12 +00:00
//读取host 连接地址 压缩类型
2019-01-05 19:16:46 +00:00
func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
2018-12-06 12:45:14 +00:00
retry:
2019-01-03 16:21:23 +00:00
lType, err := s.ReadLen(3)
if err != nil {
return
}
2019-01-03 16:21:23 +00:00
if typeStr = string(lType); typeStr == TEST_FLAG {
2019-01-05 19:16:46 +00:00
en, de, crypt, mux = s.GetConnInfoFromConn()
2018-12-06 12:45:14 +00:00
goto retry
2019-01-11 13:07:49 +00:00
} else if typeStr != CONN_TCP && typeStr != CONN_UDP {
err = errors.New("unknown conn type")
return
2018-12-06 12:45:14 +00:00
}
2019-01-03 16:21:23 +00:00
cLen, err := s.GetLen()
2019-01-11 13:07:49 +00:00
if err != nil || cLen > poolSize {
return
2018-11-29 11:55:24 +00:00
}
2019-01-03 16:21:23 +00:00
hostByte, err := s.ReadLen(cLen)
if err != nil {
return
2018-11-29 11:55:24 +00:00
}
host = string(hostByte)
return
2018-11-29 11:55:24 +00:00
}
2018-12-11 08:37:12 +00:00
//写连接类型 和 host地址
func (s *Conn) WriteHost(ltype string, host string) (int, error) {
2018-11-29 11:55:24 +00:00
raw := bytes.NewBuffer([]byte{})
binary.Write(raw, binary.LittleEndian, []byte(ltype))
2018-11-29 11:55:24 +00:00
binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
binary.Write(raw, binary.LittleEndian, []byte(host))
return s.Write(raw.Bytes())
}
//设置连接为长连接
func (s *Conn) SetAlive() {
2019-01-09 12:33:00 +00:00
conn := s.Conn.(*net.TCPConn)
2018-11-29 11:55:24 +00:00
conn.SetReadDeadline(time.Time{})
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
}
2019-01-10 01:33:05 +00:00
//从tcp报文中解析出host连接类型等 TODO 多种情况
2018-12-30 14:36:15 +00:00
func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
2019-01-11 17:22:53 +00:00
var b [32 * 1024]byte
var n int
if n, err = s.Read(b[:]); err != nil {
2018-11-30 18:38:29 +00:00
return
}
2019-01-11 17:22:53 +00:00
rb = b[:n]
r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
2018-11-30 18:38:29 +00:00
if err != nil {
return
}
2018-12-11 08:37:12 +00:00
hostPortURL, err := url.Parse(r.Host)
2018-11-30 18:38:29 +00:00
if err != nil {
2019-01-11 17:22:53 +00:00
address = r.Host
err = nil
2018-11-30 18:38:29 +00:00
return
}
if hostPortURL.Opaque == "443" { //https访问
2018-12-11 08:37:12 +00:00
address = r.Host + ":443"
2018-11-30 18:38:29 +00:00
} else { //http访问
if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口 默认80
2018-12-11 08:37:12 +00:00
address = r.Host + ":80"
2018-11-30 18:38:29 +00:00
} else {
2018-12-11 08:37:12 +00:00
address = r.Host
2018-11-30 18:38:29 +00:00
}
}
return
}
2019-01-03 16:21:23 +00:00
//单独读(加密|压缩)
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
if COMPRESS_SNAPY_DECODE == compress {
2019-01-09 12:33:00 +00:00
return NewSnappyConn(s.Conn, crypt).Read(b)
}
2019-01-09 12:33:00 +00:00
return NewCryptConn(s.Conn, crypt).Read(b)
}
2019-01-03 16:21:23 +00:00
//单独写(加密|压缩)
func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
if COMPRESS_SNAPY_ENCODE == compress {
2019-01-09 12:33:00 +00:00
return NewSnappyConn(s.Conn, crypt).Write(b)
}
2019-01-09 12:33:00 +00:00
return NewCryptConn(s.Conn, crypt).Write(b)
}
2018-11-29 11:55:24 +00:00
2019-01-03 16:21:23 +00:00
//写压缩方式,加密
2019-01-05 19:16:46 +00:00
func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
2018-12-11 08:37:12 +00:00
}
2019-01-03 16:21:23 +00:00
//获取压缩方式,是否加密
2019-01-05 19:16:46 +00:00
func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
buf, err := s.ReadLen(4)
2019-01-03 16:21:23 +00:00
if err != nil {
return
}
2018-12-11 08:37:12 +00:00
en, _ = strconv.Atoi(string(buf[0]))
de, _ = strconv.Atoi(string(buf[1]))
2019-01-02 17:44:45 +00:00
crypt = GetBoolByStr(string(buf[2]))
2019-01-05 19:16:46 +00:00
mux = GetBoolByStr(string(buf[3]))
2018-12-11 08:37:12 +00:00
return
}
2019-01-03 16:21:23 +00:00
//close
2018-12-11 08:37:12 +00:00
func (s *Conn) Close() error {
2019-01-09 12:33:00 +00:00
return s.Conn.Close()
2018-12-11 08:37:12 +00:00
}
2019-01-03 16:21:23 +00:00
//write
2018-12-11 08:37:12 +00:00
func (s *Conn) Write(b []byte) (int, error) {
2019-01-09 12:33:00 +00:00
return s.Conn.Write(b)
2018-12-11 08:37:12 +00:00
}
2019-01-03 16:21:23 +00:00
//read
2018-12-11 08:37:12 +00:00
func (s *Conn) Read(b []byte) (int, error) {
2019-01-09 12:33:00 +00:00
return s.Conn.Read(b)
2018-12-11 08:37:12 +00:00
}
2019-01-03 16:21:23 +00:00
//write error
2019-01-09 12:33:00 +00:00
func (s *Conn) WriteError() (int, error) {
2018-12-11 08:37:12 +00:00
return s.Write([]byte(RES_MSG))
2018-12-06 12:45:14 +00:00
}
2018-12-11 08:37:12 +00:00
2019-01-03 16:21:23 +00:00
//write sign flag
2019-01-09 12:33:00 +00:00
func (s *Conn) WriteSign() (int, error) {
2018-12-11 08:37:12 +00:00
return s.Write([]byte(RES_SIGN))
2018-11-29 11:55:24 +00:00
}
2018-12-06 12:45:14 +00:00
2019-01-03 16:21:23 +00:00
//write main
2019-01-09 12:33:00 +00:00
func (s *Conn) WriteMain() (int, error) {
2018-12-11 08:37:12 +00:00
return s.Write([]byte(WORK_MAIN))
}
2018-11-29 11:55:24 +00:00
2019-01-03 16:21:23 +00:00
//write chan
2019-01-09 12:33:00 +00:00
func (s *Conn) WriteChan() (int, error) {
2018-12-11 08:37:12 +00:00
return s.Write([]byte(WORK_CHAN))
2018-11-29 11:55:24 +00:00
}
2019-01-03 16:21:23 +00:00
//write test
2019-01-09 12:33:00 +00:00
func (s *Conn) WriteTest() (int, error) {
2018-12-11 08:37:12 +00:00
return s.Write([]byte(TEST_FLAG))
2018-11-29 11:55:24 +00:00
}
2019-01-02 17:44:45 +00:00
2019-01-06 17:52:54 +00:00
//write test
2019-01-09 12:33:00 +00:00
func (s *Conn) WriteSuccess() (int, error) {
2019-01-06 17:52:54 +00:00
return s.Write([]byte(CONN_SUCCESS))
}
//write test
2019-01-09 12:33:00 +00:00
func (s *Conn) WriteFail() (int, error) {
2019-01-06 17:52:54 +00:00
return s.Write([]byte(CONN_ERROR))
}
2019-01-02 17:44:45 +00:00
// GetLenBytes returns a new frame consisting of len(buf) encoded as a 4-byte
// little-endian integer, followed by a copy of buf. The error is always nil;
// it is kept for API compatibility.
func GetLenBytes(buf []byte) (b []byte, err error) {
	var frame bytes.Buffer
	frame.Grow(4 + len(buf))
	var header [4]byte
	binary.LittleEndian.PutUint32(header[:], uint32(len(buf)))
	frame.Write(header[:])
	frame.Write(buf)
	return frame.Bytes(), nil
}
// GetLenByBytes decodes the 4-byte little-endian length header written by
// GetLenBytes from the start of buf.
//
// It returns an error instead of panicking when buf holds fewer than 4 bytes.
// It also returns an error when the decoded length is zero, or does not fit in
// a non-negative int (possible only where int is 32 bits).
func GetLenByBytes(buf []byte) (int, error) {
	if len(buf) < 4 {
		return 0, errors.New("数据长度错误")
	}
	n := int(binary.LittleEndian.Uint32(buf))
	if n <= 0 {
		return 0, errors.New("数据长度错误")
	}
	return n, nil
}