diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 144cd3fe..81c023ed 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -29,6 +29,8 @@ func (b *Buffer) Release() { } b.v = nil b.pool = nil + b.start = 0 + b.end = 0 } // Clear clears the content of the buffer, results an empty buffer with diff --git a/common/serial/typed_message.go b/common/serial/typed_message.go index b19e8b60..535487ca 100644 --- a/common/serial/typed_message.go +++ b/common/serial/typed_message.go @@ -23,11 +23,11 @@ func GetMessageType(message proto.Message) string { } func GetInstance(messageType string) (interface{}, error) { - mType := proto.MessageType(messageType).Elem() - if mType == nil { + mType := proto.MessageType(messageType) + if mType == nil || mType.Elem() == nil { return nil, errors.New("Unknown type: " + messageType) } - return reflect.New(mType).Interface(), nil + return reflect.New(mType.Elem()).Interface(), nil } func (v *TypedMessage) GetInstance() (interface{}, error) { diff --git a/proxy/handler_cache.go b/proxy/handler_cache.go index 774e62da..ef9a7d5d 100644 --- a/proxy/handler_cache.go +++ b/proxy/handler_cache.go @@ -4,6 +4,7 @@ import ( "v2ray.com/core/app" "v2ray.com/core/common" "v2ray.com/core/common/errors" + v2net "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" ) @@ -37,6 +38,8 @@ func CreateInboundHandler(name string, space app.Space, config interface{}, meta meta.StreamSettings = &internet.StreamConfig{ Network: creator.StreamCapability().Get(0), } + } else if meta.StreamSettings.Network == v2net.Network_Unknown { + meta.StreamSettings.Network = creator.StreamCapability().Get(0) } else { if !creator.StreamCapability().HasNetwork(meta.StreamSettings.Network) { return nil, errors.New("Proxy: Invalid network: " + meta.StreamSettings.Network.String()) @@ -55,6 +58,8 @@ func CreateOutboundHandler(name string, space app.Space, config interface{}, met meta.StreamSettings = &internet.StreamConfig{ Network: creator.StreamCapability().Get(0), } + } else if meta.StreamSettings.Network == v2net.Network_Unknown { + meta.StreamSettings.Network = creator.StreamCapability().Get(0) } else { if !creator.StreamCapability().HasNetwork(meta.StreamSettings.Network) { return nil, errors.New("Proxy: Invalid network: " + meta.StreamSettings.Network.String()) diff --git a/testing/scenarios/common.go b/testing/scenarios/common.go index eb73a5be..fe919f9d 100644 --- a/testing/scenarios/common.go +++ b/testing/scenarios/common.go @@ -1,9 +1,12 @@ package scenarios import ( - "github.com/golang/protobuf/proto" "sync/atomic" "time" + + "net" + + "github.com/golang/protobuf/proto" "v2ray.com/core" v2net "v2ray.com/core/common/net" ) @@ -16,6 +19,32 @@ func pickPort() v2net.Port { return v2net.Port(atomic.AddUint32(&port, 1)) } +func xor(b []byte) []byte { + r := make([]byte, len(b)) + for i, v := range b { + r[i] = v ^ 'c' + } + return r +} + +func readFrom(conn net.Conn, timeout time.Duration, length int) []byte { + b := make([]byte, 2048) + totalBytes := 0 + deadline := time.Now().Add(timeout) + conn.SetReadDeadline(deadline) + for totalBytes < length { + if time.Now().After(deadline) { + break + } + n, err := conn.Read(b[totalBytes:]) + if err != nil { + break + } + totalBytes += n + } + return b[:totalBytes] +} + func InitializeServerConfig(config *core.Config) error { err := BuildV2Ray() if err != nil { diff --git a/testing/scenarios/tls_test.go b/testing/scenarios/tls_test.go index 75c028d3..6b92f71f 100644 --- a/testing/scenarios/tls_test.go +++ b/testing/scenarios/tls_test.go @@ -1,15 +1,148 @@ package scenarios import ( + "net" + "path/filepath" + "testing" + + "io/ioutil" + "os" + + "time" + "v2ray.com/core" v2net "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" + "v2ray.com/core/common/serial" + "v2ray.com/core/common/uuid" + "v2ray.com/core/proxy/dokodemo" + "v2ray.com/core/proxy/freedom" + "v2ray.com/core/proxy/vmess" + "v2ray.com/core/proxy/vmess/inbound" + "v2ray.com/core/proxy/vmess/outbound" + "v2ray.com/core/testing/assert" + "v2ray.com/core/testing/servers/tcp" + "v2ray.com/core/transport/internet" + "v2ray.com/core/transport/internet/tls" ) -var clientConfig = &core.Config{ - Inbound: []*core.InboundConnectionConfig{ - { - PortRange: v2net.SinglePortRange(pickPort()), - ListenOn: v2net.NewIPOrDomain(v2net.LocalHostIP), - }, - }, +func mustReadFile(name string) []byte { + content, err := ioutil.ReadFile(name) + if err != nil { + panic(err) + } + return content +} + +func TestSimpleTLSConnection(t *testing.T) { + assert := assert.On(t) + + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + assert.Error(err).IsNil() + + userID := protocol.NewID(uuid.New()) + serverPort := pickPort() + serverConfig := &core.Config{ + Inbound: []*core.InboundConnectionConfig{ + { + PortRange: v2net.SinglePortRange(serverPort), + ListenOn: v2net.NewIPOrDomain(v2net.LocalHostIP), + Settings: serial.ToTypedMessage(&inbound.Config{ + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + }), + }, + }, + }), + StreamSettings: &internet.StreamConfig{ + SecurityType: serial.GetMessageType(&tls.Config{}), + SecuritySettings: []*serial.TypedMessage{ + serial.ToTypedMessage(&tls.Config{ + Certificate: []*tls.Certificate{ + { + Certificate: mustReadFile(filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "testing", "tls", "cert.pem")), + Key: mustReadFile(filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "testing", "tls", "key.pem")), + }, + }, + }), + }, + }, + }, + }, + Outbound: []*core.OutboundConnectionConfig{ + { + Settings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + clientPort := pickPort() + clientConfig := &core.Config{ + Inbound: []*core.InboundConnectionConfig{ + { + PortRange: v2net.SinglePortRange(clientPort), + ListenOn: v2net.NewIPOrDomain(v2net.LocalHostIP), + Settings: serial.ToTypedMessage(&dokodemo.Config{ + Address: v2net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + NetworkList: &v2net.NetworkList{ + Network: []v2net.Network{v2net.Network_TCP}, + }, + }), + }, + }, + Outbound: []*core.OutboundConnectionConfig{ + { + Settings: serial.ToTypedMessage(&outbound.Config{ + Receiver: []*protocol.ServerEndpoint{ + { + Address: v2net.NewIPOrDomain(v2net.LocalHostIP), + Port: uint32(serverPort), + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + }), + }, + }, + }, + }, + }), + StreamSettings: &internet.StreamConfig{ + SecurityType: serial.GetMessageType(&tls.Config{}), + SecuritySettings: []*serial.TypedMessage{ + serial.ToTypedMessage(&tls.Config{ + AllowInsecure: true, + }), + }, + }, + }, + }, + } + + assert.Error(InitializeServerConfig(serverConfig)).IsNil() + assert.Error(InitializeServerConfig(clientConfig)).IsNil() + + conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: int(clientPort), + }) + + payload := "dokodemo request." + nBytes, err := conn.Write([]byte(payload)) + assert.Error(err).IsNil() + assert.Int(nBytes).Equals(len(payload)) + + conn.CloseWrite() + + response := readFrom(conn, time.Second*2, len(payload)) + assert.Bytes(response).Equals(xor([]byte(payload))) + conn.Close() + + CloseAllServers() } diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index a1073480..e29c350c 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -26,6 +26,7 @@ func (v *Config) BuildCertificates() []tls.Certificate { func (v *Config) GetTLSConfig() *tls.Config { config := &tls.Config{ ClientSessionCache: globalSessionCache, + NextProtos: []string{"http/2", "spdy/3"}, } if v == nil { return config