remove use of context.WithValue in transport

pull/1435/head
Darien Raymond 2018-11-21 14:54:40 +01:00
parent d2d0c69f17
commit 5279296f03
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
32 changed files with 212 additions and 297 deletions

View File

@ -125,14 +125,14 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
if !d.config.FollowRedirect { if !d.config.FollowRedirect {
writer = &buf.SequentialWriter{Writer: conn} writer = &buf.SequentialWriter{Writer: conn}
} else { } else {
tCtx := internet.ContextWithBindAddress(context.Background(), dest) sockopt := &internet.SocketConfig{
tCtx = internet.ContextWithStreamSettings(tCtx, &internet.MemoryStreamConfig{ Tproxy: internet.SocketConfig_TProxy,
ProtocolName: "udp", }
SocketSettings: &internet.SocketConfig{ if dest.Address.Family().IsIP() {
Tproxy: internet.SocketConfig_TProxy, sockopt.BindAddress = dest.Address.IP()
}, sockopt.BindPort = uint32(dest.Port)
}) }
tConn, err := internet.DialSystem(tCtx, net.DestinationFromAddr(conn.RemoteAddr())) tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
if err != nil { if err != nil {
return err return err
} }

View File

@ -22,10 +22,10 @@ type Server struct {
} }
func (server *Server) Start() (net.Destination, error) { func (server *Server) Start() (net.Destination, error) {
return server.StartContext(context.Background()) return server.StartContext(context.Background(), nil)
} }
func (server *Server) StartContext(ctx context.Context) (net.Destination, error) { func (server *Server) StartContext(ctx context.Context, sockopt *internet.SocketConfig) (net.Destination, error) {
listenerAddr := server.Listen listenerAddr := server.Listen
if listenerAddr == nil { if listenerAddr == nil {
listenerAddr = net.LocalHostIP listenerAddr = net.LocalHostIP
@ -33,7 +33,7 @@ func (server *Server) StartContext(ctx context.Context) (net.Destination, error)
listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
IP: listenerAddr.IP(), IP: listenerAddr.IP(),
Port: int(server.Port), Port: int(server.Port),
}) }, sockopt)
if err != nil { if err != nil {
return net.Destination{}, err return net.Destination{}, err
} }

View File

@ -310,6 +310,8 @@ type SocketConfig struct {
// ReceiveOriginalDestAddress is for enabling IP_RECVORIGDSTADDR socket option. // ReceiveOriginalDestAddress is for enabling IP_RECVORIGDSTADDR socket option.
// This option is for UDP only. // This option is for UDP only.
ReceiveOriginalDestAddress bool `protobuf:"varint,4,opt,name=receive_original_dest_address,json=receiveOriginalDestAddress,proto3" json:"receive_original_dest_address,omitempty"` ReceiveOriginalDestAddress bool `protobuf:"varint,4,opt,name=receive_original_dest_address,json=receiveOriginalDestAddress,proto3" json:"receive_original_dest_address,omitempty"`
BindAddress []byte `protobuf:"bytes,5,opt,name=bind_address,json=bindAddress,proto3" json:"bind_address,omitempty"`
BindPort uint32 `protobuf:"varint,6,opt,name=bind_port,json=bindPort,proto3" json:"bind_port,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"` XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"` XXX_sizecache int32 `json:"-"`
@ -368,6 +370,20 @@ func (m *SocketConfig) GetReceiveOriginalDestAddress() bool {
return false return false
} }
func (m *SocketConfig) GetBindAddress() []byte {
if m != nil {
return m.BindAddress
}
return nil
}
func (m *SocketConfig) GetBindPort() uint32 {
if m != nil {
return m.BindPort
}
return 0
}
func init() { func init() {
proto.RegisterEnum("v2ray.core.transport.internet.TransportProtocol", TransportProtocol_name, TransportProtocol_value) proto.RegisterEnum("v2ray.core.transport.internet.TransportProtocol", TransportProtocol_name, TransportProtocol_value)
proto.RegisterEnum("v2ray.core.transport.internet.SocketConfig_TCPFastOpenState", SocketConfig_TCPFastOpenState_name, SocketConfig_TCPFastOpenState_value) proto.RegisterEnum("v2ray.core.transport.internet.SocketConfig_TCPFastOpenState", SocketConfig_TCPFastOpenState_name, SocketConfig_TCPFastOpenState_value)
@ -383,43 +399,45 @@ func init() {
} }
var fileDescriptor_91dbc815c3d97a05 = []byte{ var fileDescriptor_91dbc815c3d97a05 = []byte{
// 607 bytes of a gzipped FileDescriptorProto // 636 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x53, 0xdd, 0x6e, 0xd3, 0x4c, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x53, 0xdd, 0x6a, 0xdb, 0x4c,
0x10, 0xad, 0xed, 0x34, 0x4d, 0x27, 0x69, 0xea, 0xee, 0x55, 0x54, 0xa9, 0xfa, 0xfa, 0x05, 0x09, 0x10, 0x8d, 0x2c, 0xc7, 0x91, 0xc7, 0x8e, 0xa3, 0xec, 0x95, 0xc9, 0x47, 0xf8, 0x12, 0x17, 0x8a,
0x45, 0x20, 0xad, 0x2b, 0x23, 0xb8, 0xe2, 0xa6, 0x4d, 0x40, 0x54, 0xd0, 0xc6, 0x72, 0x0c, 0x48, 0x69, 0x41, 0x0a, 0x2a, 0xed, 0x55, 0x6f, 0x12, 0xbb, 0xa5, 0xa1, 0x4d, 0x2c, 0x64, 0xb5, 0x85,
0x95, 0x90, 0xb5, 0x75, 0x26, 0x91, 0xd5, 0xd8, 0x1b, 0xed, 0x2e, 0x15, 0x79, 0x25, 0xae, 0x79, 0x40, 0x11, 0x6b, 0x69, 0x6c, 0x44, 0x2c, 0xad, 0xd9, 0xdd, 0x86, 0xfa, 0x95, 0x0a, 0xbd, 0xeb,
0x08, 0x5e, 0x86, 0x77, 0x40, 0xbb, 0xfe, 0x21, 0x2a, 0x28, 0xb4, 0xe2, 0x6e, 0x3c, 0x73, 0xe6, 0x43, 0xf4, 0xb1, 0xca, 0xae, 0x7e, 0x6a, 0xd2, 0x92, 0x26, 0xf4, 0x6e, 0x34, 0x73, 0xe6, 0xcc,
0xcc, 0x39, 0x33, 0x5e, 0xa0, 0xb7, 0xbe, 0x60, 0x2b, 0x9a, 0xf0, 0xcc, 0x4b, 0xb8, 0x40, 0x4f, 0x9c, 0x39, 0x5a, 0x70, 0x6e, 0x3c, 0x4e, 0xd7, 0x4e, 0xcc, 0x32, 0x37, 0x66, 0x1c, 0x5d, 0xc9,
0x09, 0x96, 0xcb, 0x25, 0x17, 0xca, 0x4b, 0x73, 0x85, 0x22, 0x47, 0xe5, 0x25, 0x3c, 0x9f, 0xa5, 0x69, 0x2e, 0x56, 0x8c, 0x4b, 0x37, 0xcd, 0x25, 0xf2, 0x1c, 0xa5, 0x1b, 0xb3, 0x7c, 0x9e, 0x2e,
0x73, 0xba, 0x14, 0x5c, 0x71, 0x72, 0x54, 0xe1, 0x05, 0xd2, 0x1a, 0x4b, 0x2b, 0xec, 0xe1, 0xc9, 0x9c, 0x15, 0x67, 0x92, 0x91, 0xc3, 0x0a, 0xcf, 0xd1, 0xa9, 0xb1, 0x4e, 0x85, 0x3d, 0x38, 0xb9,
0x1d, 0xba, 0x84, 0x67, 0x19, 0xcf, 0x3d, 0x89, 0x22, 0x65, 0x0b, 0x4f, 0xad, 0x96, 0x38, 0x8d, 0x45, 0x17, 0xb3, 0x2c, 0x63, 0xb9, 0x2b, 0x90, 0xa7, 0x74, 0xe9, 0xca, 0xf5, 0x0a, 0x93, 0x28,
0x33, 0x94, 0x92, 0xcd, 0xb1, 0x20, 0xec, 0x7f, 0xb7, 0x60, 0x3f, 0xaa, 0x88, 0x86, 0x66, 0x14, 0x43, 0x21, 0xe8, 0x02, 0x0b, 0xc2, 0xc1, 0x0f, 0x03, 0xf6, 0xc2, 0x8a, 0x68, 0xa4, 0x47, 0x91,
0x79, 0x07, 0x2d, 0x53, 0x4c, 0xf8, 0xa2, 0x67, 0x1d, 0x5b, 0x83, 0xae, 0x7f, 0x42, 0x37, 0xce, 0x77, 0x60, 0xe9, 0x62, 0xcc, 0x96, 0x7d, 0xe3, 0xc8, 0x18, 0xf6, 0xbc, 0x13, 0xe7, 0xce, 0xb9,
0xa5, 0x35, 0x43, 0x50, 0xf6, 0x85, 0x35, 0x03, 0x79, 0x04, 0x7b, 0x55, 0x1c, 0xe7, 0x2c, 0xc3, 0x4e, 0xcd, 0xe0, 0x97, 0x7d, 0x41, 0xcd, 0x40, 0x1e, 0xc1, 0x6e, 0x15, 0x47, 0x39, 0xcd, 0xb0,
0x9e, 0x73, 0x6c, 0x0d, 0x76, 0xc3, 0x4e, 0x95, 0xbc, 0x64, 0x19, 0x92, 0x33, 0x68, 0x49, 0x54, 0x6f, 0x1e, 0x19, 0xc3, 0x76, 0xd0, 0xad, 0x92, 0x97, 0x34, 0x43, 0x72, 0x06, 0x96, 0x40, 0x29,
0x2a, 0xcd, 0xe7, 0xb2, 0x67, 0x1f, 0x5b, 0x83, 0xb6, 0xff, 0x78, 0x7d, 0x64, 0xe1, 0x83, 0x16, 0xd3, 0x7c, 0x21, 0xfa, 0x8d, 0x23, 0x63, 0xd8, 0xf1, 0x1e, 0x6f, 0x8e, 0x2c, 0x74, 0x38, 0x85,
0x3e, 0x68, 0xa4, 0x7d, 0x5c, 0x14, 0x36, 0xc2, 0xba, 0xaf, 0xff, 0xcd, 0x81, 0xce, 0x44, 0x09, 0x0e, 0x27, 0x54, 0x3a, 0x2e, 0x0a, 0x19, 0x41, 0xdd, 0x37, 0xf8, 0x6e, 0x42, 0x77, 0x2a, 0x39,
0x64, 0x59, 0xe9, 0x23, 0xf8, 0x77, 0x1f, 0x67, 0x76, 0xcf, 0xda, 0xe4, 0x65, 0xfb, 0x0f, 0x5e, 0xd2, 0xac, 0xd4, 0xe1, 0xff, 0xbb, 0x8e, 0xb3, 0x46, 0xdf, 0xb8, 0x4b, 0xcb, 0xf6, 0x1f, 0xb4,
0x3e, 0x01, 0xa9, 0xa9, 0xe3, 0x35, 0x57, 0xce, 0xa0, 0xed, 0xd3, 0xfb, 0x0a, 0x28, 0x2c, 0x84, 0x7c, 0x02, 0x52, 0x53, 0x47, 0x1b, 0xaa, 0xcc, 0x61, 0xc7, 0x73, 0xee, 0xbb, 0x40, 0x21, 0x21,
0x07, 0x35, 0x66, 0x52, 0x12, 0x69, 0x0d, 0x12, 0x93, 0xcf, 0x22, 0x55, 0xab, 0x58, 0x5f, 0xb4, 0xd8, 0xaf, 0x31, 0xd3, 0x92, 0x48, 0xed, 0x20, 0x30, 0xfe, 0xcc, 0x53, 0xb9, 0x8e, 0x94, 0xa3,
0xda, 0x67, 0x95, 0xd4, 0xdb, 0x21, 0x13, 0x38, 0xa8, 0x41, 0xb5, 0x84, 0x86, 0x91, 0x70, 0xdf, 0xd5, 0x3d, 0xab, 0xa4, 0xba, 0x0e, 0x99, 0xc2, 0x7e, 0x0d, 0xaa, 0x57, 0x68, 0xea, 0x15, 0xee,
0xc5, 0xba, 0x15, 0x41, 0x3d, 0x39, 0x82, 0x7d, 0xc9, 0x93, 0x1b, 0x5c, 0x73, 0xd5, 0x34, 0xb7, 0x7b, 0x58, 0xbb, 0x22, 0xa8, 0x27, 0x87, 0xb0, 0x27, 0x58, 0x7c, 0x8d, 0x1b, 0xaa, 0x5a, 0xda,
0x7a, 0xfa, 0x17, 0x57, 0x13, 0xd3, 0x55, 0x5a, 0xea, 0x16, 0x1c, 0x15, 0x6b, 0xff, 0x3f, 0x68, 0xab, 0xa7, 0x7f, 0x51, 0x35, 0xd5, 0x5d, 0xa5, 0xa4, 0x5e, 0xc1, 0x51, 0xb1, 0x0e, 0xfe, 0x87,
0x07, 0x82, 0x7f, 0x59, 0x95, 0x47, 0x73, 0xc1, 0x51, 0x6c, 0x6e, 0xee, 0xb5, 0x1b, 0xea, 0xb0, 0x8e, 0xcf, 0xd9, 0x97, 0x75, 0x69, 0x9a, 0x0d, 0xa6, 0xa4, 0x0b, 0xed, 0x57, 0x3b, 0x50, 0xe1,
0xff, 0xc3, 0x86, 0xce, 0x3a, 0x03, 0x21, 0xd0, 0xc8, 0x98, 0xb8, 0x31, 0x98, 0xed, 0xd0, 0xc4, 0xe0, 0x9b, 0xf2, 0x75, 0x83, 0x81, 0x10, 0x68, 0x66, 0x94, 0x5f, 0x6b, 0xcc, 0x76, 0xa0, 0x63,
0xe4, 0x12, 0x1c, 0x35, 0xe3, 0xe6, 0xdf, 0xe9, 0xfa, 0x2f, 0x1f, 0xa0, 0x87, 0x46, 0xc3, 0xe0, 0x72, 0x09, 0xa6, 0x9c, 0x33, 0xfd, 0xef, 0xf4, 0xbc, 0x97, 0x0f, 0xd8, 0xc7, 0x09, 0x47, 0xfe,
0x35, 0x93, 0x6a, 0xbc, 0xc4, 0x7c, 0xa2, 0x98, 0xc2, 0x50, 0x13, 0x91, 0x4b, 0x68, 0xaa, 0xa5, 0x6b, 0x2a, 0xe4, 0x64, 0x85, 0xf9, 0x54, 0x52, 0x89, 0x81, 0x22, 0x22, 0x97, 0xd0, 0x92, 0x2b,
0x96, 0x65, 0xd6, 0xdb, 0xf5, 0x5f, 0x3c, 0x88, 0xd2, 0x18, 0xba, 0xe0, 0x53, 0x0c, 0x4b, 0x16, 0xb5, 0x96, 0x3e, 0x6f, 0xcf, 0x7b, 0xf1, 0x20, 0x4a, 0x2d, 0xe8, 0x82, 0x25, 0x18, 0x94, 0x2c,
0x72, 0x0a, 0x47, 0x02, 0x13, 0x4c, 0x6f, 0x31, 0xe6, 0x22, 0x9d, 0xa7, 0x39, 0x5b, 0xc4, 0x53, 0xe4, 0x14, 0x0e, 0x39, 0xc6, 0x98, 0xde, 0x60, 0xc4, 0x78, 0xba, 0x48, 0x73, 0xba, 0x8c, 0x12,
0x94, 0x2a, 0x66, 0xd3, 0xa9, 0x40, 0xa9, 0x8f, 0x63, 0x0d, 0x5a, 0xe1, 0x61, 0x09, 0x1a, 0x97, 0x14, 0x32, 0xa2, 0x49, 0xc2, 0x51, 0x28, 0x73, 0x8c, 0xa1, 0x15, 0x1c, 0x94, 0xa0, 0x49, 0x89,
0x98, 0x11, 0x4a, 0x75, 0x5a, 0x20, 0xfa, 0xcf, 0xc1, 0xbd, 0xab, 0x95, 0xb4, 0xa0, 0x71, 0x2a, 0x19, 0xa3, 0x90, 0xa7, 0x05, 0x82, 0x1c, 0x43, 0x77, 0x96, 0xe6, 0x49, 0xdd, 0xa1, 0xfe, 0xbd,
0xcf, 0xa5, 0xbb, 0x45, 0x00, 0x9a, 0xaf, 0x72, 0x76, 0xbd, 0x40, 0xd7, 0x22, 0x6d, 0xd8, 0x19, 0x6e, 0xd0, 0x51, 0xb9, 0x0a, 0xf2, 0x1f, 0xb4, 0x35, 0x44, 0xed, 0xa6, 0xbd, 0xd9, 0x0d, 0x2c,
0xa5, 0xd2, 0x7c, 0xd8, 0x7d, 0x0f, 0xe0, 0x97, 0x1e, 0xb2, 0x03, 0xce, 0x78, 0x36, 0x2b, 0xf0, 0x95, 0xf0, 0x19, 0x97, 0x83, 0xe7, 0x60, 0xdf, 0xd6, 0x4a, 0x2c, 0x68, 0x9e, 0x8a, 0x73, 0x61,
0x45, 0xda, 0xb5, 0x48, 0x07, 0x5a, 0x21, 0x4e, 0x53, 0x81, 0x89, 0x72, 0xed, 0x27, 0x57, 0x70, 0x6f, 0x11, 0x80, 0xd6, 0xab, 0x9c, 0xce, 0x96, 0x68, 0x1b, 0xa4, 0x03, 0x3b, 0xe3, 0x54, 0xe8,
0xf0, 0xdb, 0x3b, 0xd0, 0x7d, 0xd1, 0x30, 0x70, 0xb7, 0x74, 0xf0, 0x7e, 0x14, 0xb8, 0x96, 0x1e, 0x8f, 0xc6, 0xc0, 0x05, 0xf8, 0xa5, 0x87, 0xec, 0x80, 0x39, 0x99, 0xcf, 0x0b, 0x7c, 0x91, 0xb6,
0x7d, 0xf1, 0x76, 0x18, 0xb8, 0x36, 0xd9, 0x83, 0xdd, 0x8f, 0x78, 0x5d, 0x6c, 0xc0, 0x75, 0x74, 0x0d, 0xd2, 0x05, 0x2b, 0xc0, 0x24, 0xe5, 0x18, 0x4b, 0xbb, 0xf1, 0xe4, 0x0a, 0xf6, 0x7f, 0x7b,
0xe1, 0x4d, 0x14, 0x05, 0x6e, 0x83, 0xb8, 0xd0, 0x19, 0xf1, 0x8c, 0xa5, 0x79, 0x59, 0xdb, 0x3e, 0x47, 0xaa, 0x2f, 0x1c, 0xf9, 0xf6, 0x96, 0x0a, 0xde, 0x8f, 0x7d, 0xdb, 0x50, 0xa3, 0x2f, 0xde,
0x1b, 0xc3, 0xff, 0x09, 0xcf, 0x36, 0xef, 0x32, 0xb0, 0xae, 0x5a, 0x55, 0xfc, 0xd5, 0x3e, 0xfa, 0x8e, 0x7c, 0xbb, 0x41, 0x76, 0xa1, 0xfd, 0x11, 0x67, 0xc5, 0x05, 0x6d, 0x53, 0x15, 0xde, 0x84,
0xe0, 0x87, 0x6c, 0x45, 0x87, 0x1a, 0x5b, 0xcb, 0xa2, 0xe7, 0x65, 0xfd, 0xba, 0x69, 0x9e, 0xde, 0xa1, 0x6f, 0x37, 0x89, 0x0d, 0xdd, 0x31, 0xcb, 0x68, 0x9a, 0x97, 0xb5, 0xed, 0xb3, 0x09, 0x1c,
0xb3, 0x9f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xf2, 0x99, 0xd5, 0x37, 0x49, 0x05, 0x00, 0x00, 0xc7, 0x2c, 0xbb, 0xdb, 0x0b, 0xdf, 0xb8, 0xb2, 0xaa, 0xf8, 0x6b, 0xe3, 0xf0, 0x83, 0x17, 0xd0,
0xb5, 0x33, 0x52, 0xd8, 0x7a, 0x2d, 0xe7, 0xbc, 0xac, 0xcf, 0x5a, 0xfa, 0xe9, 0x3e, 0xfb, 0x19,
0x00, 0x00, 0xff, 0xff, 0xce, 0xe9, 0xc8, 0x20, 0x89, 0x05, 0x00, 0x00,
} }

View File

@ -83,4 +83,8 @@ message SocketConfig {
// ReceiveOriginalDestAddress is for enabling IP_RECVORIGDSTADDR socket option. // ReceiveOriginalDestAddress is for enabling IP_RECVORIGDSTADDR socket option.
// This option is for UDP only. // This option is for UDP only.
bool receive_original_dest_address = 4; bool receive_original_dest_address = 4;
bytes bind_address = 5;
uint32 bind_port = 6;
} }

View File

@ -1,37 +0,0 @@
package internet
import (
"context"
"v2ray.com/core/common/net"
)
type key int
const (
streamSettingsKey key = iota
bindAddrKey
)
func ContextWithStreamSettings(ctx context.Context, streamSettings *MemoryStreamConfig) context.Context {
return context.WithValue(ctx, streamSettingsKey, streamSettings)
}
func StreamSettingsFromContext(ctx context.Context) *MemoryStreamConfig {
ss := ctx.Value(streamSettingsKey)
if ss == nil {
return nil
}
return ss.(*MemoryStreamConfig)
}
func ContextWithBindAddress(ctx context.Context, dest net.Destination) context.Context {
return context.WithValue(ctx, bindAddrKey, dest)
}
func BindAddressFromContext(ctx context.Context) net.Destination {
if addr, ok := ctx.Value(bindAddrKey).(net.Destination); ok {
return addr
}
return net.Destination{}
}

View File

@ -17,7 +17,7 @@ type Dialer interface {
} }
// dialFunc is an interface to dial network connection to a specific destination. // dialFunc is an interface to dial network connection to a specific destination.
type dialFunc func(ctx context.Context, dest net.Destination) (Connection, error) type dialFunc func(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (Connection, error)
var ( var (
transportDialerCache = make(map[string]dialFunc) transportDialerCache = make(map[string]dialFunc)
@ -33,16 +33,14 @@ func RegisterTransportDialer(protocol string, dialer dialFunc) error {
} }
// Dial dials a internet connection towards the given destination. // Dial dials a internet connection towards the given destination.
func Dial(ctx context.Context, dest net.Destination) (Connection, error) { func Dial(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (Connection, error) {
if dest.Network == net.Network_TCP { if dest.Network == net.Network_TCP {
streamSettings := StreamSettingsFromContext(ctx)
if streamSettings == nil { if streamSettings == nil {
s, err := ToMemoryStreamConfig(nil) s, err := ToMemoryStreamConfig(nil)
if err != nil { if err != nil {
return nil, newError("failed to create default stream settings").Base(err) return nil, newError("failed to create default stream settings").Base(err)
} }
streamSettings = s streamSettings = s
ctx = ContextWithStreamSettings(ctx, streamSettings)
} }
protocol := streamSettings.ProtocolName protocol := streamSettings.ProtocolName
@ -50,7 +48,7 @@ func Dial(ctx context.Context, dest net.Destination) (Connection, error) {
if dialer == nil { if dialer == nil {
return nil, newError(protocol, " dialer not registered").AtError() return nil, newError(protocol, " dialer not registered").AtError()
} }
return dialer(ctx, dest) return dialer(ctx, dest, streamSettings)
} }
if dest.Network == net.Network_UDP { if dest.Network == net.Network_UDP {
@ -58,17 +56,17 @@ func Dial(ctx context.Context, dest net.Destination) (Connection, error) {
if udpDialer == nil { if udpDialer == nil {
return nil, newError("UDP dialer not registered").AtError() return nil, newError("UDP dialer not registered").AtError()
} }
return udpDialer(ctx, dest) return udpDialer(ctx, dest, streamSettings)
} }
return nil, newError("unknown network ", dest.Network) return nil, newError("unknown network ", dest.Network)
} }
// DialSystem calls system dialer to create a network connection. // DialSystem calls system dialer to create a network connection.
func DialSystem(ctx context.Context, dest net.Destination) (net.Conn, error) { func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
var src net.Address var src net.Address
if outbound := session.OutboundFromContext(ctx); outbound != nil { if outbound := session.OutboundFromContext(ctx); outbound != nil {
src = outbound.Gateway src = outbound.Gateway
} }
return effectiveSystemDialer.Dial(ctx, src, dest) return effectiveSystemDialer.Dial(ctx, src, dest, sockopt)
} }

View File

@ -18,7 +18,7 @@ func TestDialWithLocalAddr(t *testing.T) {
assert(err, IsNil) assert(err, IsNil)
defer server.Close() defer server.Close()
conn, err := DialSystem(context.Background(), net.TCPDestination(net.LocalHostIP, dest.Port)) conn, err := DialSystem(context.Background(), net.TCPDestination(net.LocalHostIP, dest.Port), nil)
assert(err, IsNil) assert(err, IsNil)
assert(conn.RemoteAddr().String(), Equals, "127.0.0.1:"+dest.Port.String()) assert(conn.RemoteAddr().String(), Equals, "127.0.0.1:"+dest.Port.String())
conn.Close() conn.Close()

View File

@ -12,20 +12,8 @@ import (
"v2ray.com/core/transport/internet/tls" "v2ray.com/core/transport/internet/tls"
) )
func getSettingsFromContext(ctx context.Context) *Config { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
rawSettings := internet.StreamSettingsFromContext(ctx) settings := streamSettings.ProtocolSettings.(*Config)
if rawSettings == nil {
return nil
}
return rawSettings.ProtocolSettings.(*Config)
}
func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) {
settings := getSettingsFromContext(ctx)
if settings == nil {
return nil, newError("domain socket settings is not specified.").AtError()
}
addr, err := settings.GetUnixAddr() addr, err := settings.GetUnixAddr()
if err != nil { if err != nil {
return nil, err return nil, err
@ -36,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error
return nil, newError("failed to dial unix: ", settings.Path).Base(err).AtWarning() return nil, newError("failed to dial unix: ", settings.Path).Base(err).AtWarning()
} }
if config := tls.ConfigFromContext(ctx); config != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
return tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest))), nil return tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest))), nil
} }

View File

@ -25,12 +25,8 @@ type Listener struct {
locker *fileLocker locker *fileLocker
} }
func Listen(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
settings := getSettingsFromContext(ctx) settings := streamSettings.ProtocolSettings.(*Config)
if settings == nil {
return nil, newError("domain socket settings not specified.")
}
addr, err := settings.GetUnixAddr() addr, err := settings.GetUnixAddr()
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,7 +54,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int
} }
} }
if config := tls.ConfigFromContext(ctx); config != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
ln.tlsConfig = config.GetTLSConfig() ln.tlsConfig = config.GetTLSConfig()
} }

View File

@ -18,13 +18,14 @@ import (
func TestListen(t *testing.T) { func TestListen(t *testing.T) {
assert := With(t) assert := With(t)
ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ ctx := context.Background()
streamSettings := &internet.MemoryStreamConfig{
ProtocolName: "domainsocket", ProtocolName: "domainsocket",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "/tmp/ts3", Path: "/tmp/ts3",
}, },
}) }
listener, err := Listen(ctx, nil, net.Port(0), func(conn internet.Connection) { listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn internet.Connection) {
defer conn.Close() defer conn.Close()
b := buf.New() b := buf.New()
@ -36,7 +37,7 @@ func TestListen(t *testing.T) {
assert(err, IsNil) assert(err, IsNil)
defer listener.Close() defer listener.Close()
conn, err := Dial(ctx, net.Destination{}) conn, err := Dial(ctx, net.Destination{}, streamSettings)
assert(err, IsNil) assert(err, IsNil)
defer conn.Close() defer conn.Close()
@ -56,14 +57,15 @@ func TestListenAbstract(t *testing.T) {
assert := With(t) assert := With(t)
ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ ctx := context.Background()
streamSettings := &internet.MemoryStreamConfig{
ProtocolName: "domainsocket", ProtocolName: "domainsocket",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "/tmp/ts3", Path: "/tmp/ts3",
Abstract: true, Abstract: true,
}, },
}) }
listener, err := Listen(ctx, nil, net.Port(0), func(conn internet.Connection) { listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn internet.Connection) {
defer conn.Close() defer conn.Close()
b := buf.New() b := buf.New()
@ -75,7 +77,7 @@ func TestListenAbstract(t *testing.T) {
assert(err, IsNil) assert(err, IsNil)
defer listener.Close() defer listener.Close()
conn, err := Dial(ctx, net.Destination{}) conn, err := Dial(ctx, net.Destination{}, streamSettings)
assert(err, IsNil) assert(err, IsNil)
defer conn.Close() defer conn.Close()

View File

@ -21,7 +21,7 @@ var (
globalDailerAccess sync.Mutex globalDailerAccess sync.Mutex
) )
func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, error) { func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.Config) (*http.Client, error) {
globalDailerAccess.Lock() globalDailerAccess.Lock()
defer globalDailerAccess.Unlock() defer globalDailerAccess.Unlock()
@ -33,11 +33,6 @@ func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, err
return client, nil return client, nil
} }
config := tls.ConfigFromContext(ctx)
if config == nil {
return nil, newError("TLS must be enabled for http transport.").AtWarning()
}
transport := &http2.Transport{ transport := &http2.Transport{
DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
rawHost, rawPort, err := net.SplitHostPort(addr) rawHost, rawPort, err := net.SplitHostPort(addr)
@ -53,13 +48,13 @@ func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, err
} }
address := net.ParseAddress(rawHost) address := net.ParseAddress(rawHost)
pconn, err := internet.DialSystem(context.Background(), net.TCPDestination(address, port)) pconn, err := internet.DialSystem(context.Background(), net.TCPDestination(address, port), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return gotls.Client(pconn, tlsConfig), nil return gotls.Client(pconn, tlsConfig), nil
}, },
TLSClientConfig: config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2")), TLSClientConfig: tlsSettings.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2")),
} }
client := &http.Client{ client := &http.Client{
@ -71,14 +66,13 @@ func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, err
} }
// Dial dials a new TCP connection to the given destination. // Dial dials a new TCP connection to the given destination.
func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
rawSettings := internet.StreamSettingsFromContext(ctx) httpSettings := streamSettings.ProtocolSettings.(*Config)
httpSettings, ok := rawSettings.ProtocolSettings.(*Config) tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
if !ok { if tlsConfig == nil {
return nil, newError("HTTP config is not set.").AtError() return nil, newError("TLS must be enabled for http transport.").AtWarning()
} }
client, err := getHTTPClient(ctx, dest, tlsConfig)
client, err := getHTTPClient(ctx, dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,16 +22,14 @@ func TestHTTPConnection(t *testing.T) {
port := tcp.PickPort() port := tcp.PickPort()
lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
ProtocolName: "http", ProtocolName: "http",
ProtocolSettings: &Config{}, ProtocolSettings: &Config{},
SecurityType: "tls", SecurityType: "tls",
SecuritySettings: &tls.Config{ SecuritySettings: &tls.Config{
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.v2ray.com")))}, Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.v2ray.com")))},
}, },
}) }, func(conn internet.Connection) {
listener, err := Listen(lctx, net.LocalHostIP, port, func(conn internet.Connection) {
go func() { go func() {
defer conn.Close() defer conn.Close()
@ -54,7 +52,8 @@ func TestHTTPConnection(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
dctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ dctx := context.Background()
conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
ProtocolName: "http", ProtocolName: "http",
ProtocolSettings: &Config{}, ProtocolSettings: &Config{},
SecurityType: "tls", SecurityType: "tls",
@ -63,7 +62,6 @@ func TestHTTPConnection(t *testing.T) {
AllowInsecure: true, AllowInsecure: true,
}, },
}) })
conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port))
assert(err, IsNil) assert(err, IsNil)
defer conn.Close() defer conn.Close()

View File

@ -88,13 +88,8 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request)
<-done.Wait() <-done.Wait()
} }
func Listen(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
rawSettings := internet.StreamSettingsFromContext(ctx) httpSettings := streamSettings.ProtocolSettings.(*Config)
httpSettings, ok := rawSettings.ProtocolSettings.(*Config)
if !ok {
return nil, newError("HTTP config is not set.").AtError()
}
listener := &Listener{ listener := &Listener{
handler: handler, handler: handler,
local: &net.TCPAddr{ local: &net.TCPAddr{
@ -104,7 +99,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int
config: *httpSettings, config: *httpSettings,
} }
config := tls.ConfigFromContext(ctx) config := tls.ConfigFromStreamSettings(streamSettings)
if config == nil { if config == nil {
return nil, newError("TLS must be enabled for http transport.").AtWarning() return nil, newError("TLS must be enabled for http transport.").AtWarning()
} }
@ -120,7 +115,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int
tcpListener, err := internet.ListenSystem(ctx, &net.TCPAddr{ tcpListener, err := internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
}) }, streamSettings.SocketSettings)
if err != nil { if err != nil {
newError("failed to listen on", address, ":", port).Base(err).WriteToLog(session.ExportIDToError(ctx)) newError("failed to listen on", address, ":", port).Base(err).WriteToLog(session.ExportIDToError(ctx))
return return

View File

@ -45,16 +45,16 @@ func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn
} }
} }
func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, error) { func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
dest.Network = net.Network_UDP dest.Network = net.Network_UDP
newError("dialing mKCP to ", dest).WriteToLog() newError("dialing mKCP to ", dest).WriteToLog()
rawConn, err := internet.DialSystem(ctx, dest) rawConn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil { if err != nil {
return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err) return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err)
} }
kcpSettings := internet.StreamSettingsFromContext(ctx).ProtocolSettings.(*Config) kcpSettings := streamSettings.ProtocolSettings.(*Config)
header, err := kcpSettings.GetPackerHeader() header, err := kcpSettings.GetPackerHeader()
if err != nil { if err != nil {
@ -85,7 +85,7 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
var iConn internet.Connection = session var iConn internet.Connection = session
if config := v2tls.ConfigFromContext(ctx); config != nil { if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
tlsConn := tls.Client(iConn, config.GetTLSConfig(v2tls.WithDestination(dest))) tlsConn := tls.Client(iConn, config.GetTLSConfig(v2tls.WithDestination(dest)))
iConn = tlsConn iConn = tlsConn
} }

View File

@ -17,11 +17,10 @@ import (
func TestDialAndListen(t *testing.T) { func TestDialAndListen(t *testing.T) {
assert := With(t) assert := With(t)
lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ listerner, err := NewListener(context.Background(), net.LocalHostIP, net.Port(0), &internet.MemoryStreamConfig{
ProtocolName: "mkcp", ProtocolName: "mkcp",
ProtocolSettings: &Config{}, ProtocolSettings: &Config{},
}) }, func(conn internet.Connection) {
listerner, err := NewListener(lctx, net.LocalHostIP, net.Port(0), func(conn internet.Connection) {
go func(c internet.Connection) { go func(c internet.Connection) {
payload := make([]byte, 4096) payload := make([]byte, 4096)
for { for {
@ -40,13 +39,12 @@ func TestDialAndListen(t *testing.T) {
assert(err, IsNil) assert(err, IsNil)
port := net.Port(listerner.Addr().(*net.UDPAddr).Port) port := net.Port(listerner.Addr().(*net.UDPAddr).Port)
ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{
ProtocolName: "mkcp",
ProtocolSettings: &Config{},
})
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
clientConn, err := DialKCP(ctx, net.UDPDestination(net.LocalHostIP, port)) clientConn, err := DialKCP(context.Background(), net.UDPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
ProtocolName: "mkcp",
ProtocolSettings: &Config{},
})
assert(err, IsNil) assert(err, IsNil)
wg.Add(1) wg.Add(1)

View File

@ -33,10 +33,8 @@ type Listener struct {
addConn internet.ConnHandler addConn internet.ConnHandler
} }
func NewListener(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (*Listener, error) { func NewListener(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (*Listener, error) {
networkSettings := internet.StreamSettingsFromContext(ctx) kcpSettings := streamSettings.ProtocolSettings.(*Config)
kcpSettings := networkSettings.ProtocolSettings.(*Config)
header, err := kcpSettings.GetPackerHeader() header, err := kcpSettings.GetPackerHeader()
if err != nil { if err != nil {
return nil, newError("failed to create packet header").Base(err).AtError() return nil, newError("failed to create packet header").Base(err).AtError()
@ -57,11 +55,11 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon
addConn: addConn, addConn: addConn,
} }
if config := v2tls.ConfigFromContext(ctx); config != nil { if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
l.tlsConfig = config.GetTLSConfig() l.tlsConfig = config.GetTLSConfig()
} }
hub, err := udp.ListenUDP(ctx, address, port, udp.HubCapacity(1024)) hub, err := udp.ListenUDP(ctx, address, port, streamSettings, udp.HubCapacity(1024))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -189,8 +187,8 @@ func (w *Writer) Close() error {
return nil return nil
} }
func ListenKCP(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (internet.Listener, error) { func ListenKCP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
return NewListener(ctx, address, port, addConn) return NewListener(ctx, address, port, streamSettings, addConn)
} }
func init() { func init() {

View File

@ -2,8 +2,6 @@ package internet
import ( import (
"syscall" "syscall"
"v2ray.com/core/common/net"
) )
const ( const (
@ -49,6 +47,6 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig)
return nil return nil
} }
func bindAddr(fd uintptr, address net.Address, port net.Port) error { func bindAddr(fd uintptr, address []byte, port uint32) error {
return nil return nil
} }

View File

@ -1,9 +1,8 @@
package internet package internet
import ( import (
"net"
"syscall" "syscall"
"v2ray.com/core/common/net"
) )
const ( const (
@ -13,7 +12,7 @@ const (
TCP_FASTOPEN_CONNECT = 30 TCP_FASTOPEN_CONNECT = 30
) )
func bindAddr(fd uintptr, address net.Address, port net.Port) error { func bindAddr(fd uintptr, ip []byte, port uint32) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
if err != nil { if err != nil {
return newError("failed to set resuse_addr").Base(err).AtWarning() return newError("failed to set resuse_addr").Base(err).AtWarning()
@ -21,21 +20,21 @@ func bindAddr(fd uintptr, address net.Address, port net.Port) error {
var sockaddr syscall.Sockaddr var sockaddr syscall.Sockaddr
switch address.Family() { switch len(ip) {
case net.AddressFamilyIPv4: case net.IPv4len:
a4 := &syscall.SockaddrInet4{ a4 := &syscall.SockaddrInet4{
Port: int(port), Port: int(port),
} }
copy(a4.Addr[:], address.IP()) copy(a4.Addr[:], ip)
sockaddr = a4 sockaddr = a4
case net.AddressFamilyIPv6: case net.IPv6len:
a6 := &syscall.SockaddrInet6{ a6 := &syscall.SockaddrInet6{
Port: int(port), Port: int(port),
} }
copy(a6.Addr[:], address.IP()) copy(a6.Addr[:], ip)
sockaddr = a6 sockaddr = a6
default: default:
return newError("unsupported address family: ", address.Family()) return newError("unexpected length of ip")
} }
return syscall.Bind(int(fd), sockaddr) return syscall.Bind(int(fd), sockaddr)

View File

@ -2,8 +2,6 @@
package internet package internet
import "v2ray.com/core/common/net"
func applyOutboundSocketOptions(network string, address string, fd uintptr, config *SocketConfig) error { func applyOutboundSocketOptions(network string, address string, fd uintptr, config *SocketConfig) error {
return nil return nil
} }
@ -12,6 +10,6 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig)
return nil return nil
} }
func bindAddr(fd uintptr, address net.Address, port net.Port) error { func bindAddr(fd uintptr, ip []byte, port uint32) error {
return nil return nil
} }

View File

@ -17,22 +17,15 @@ func TestTCPFastOpen(t *testing.T) {
return b return b
}, },
} }
dest, err := tcpServer.StartContext(ContextWithStreamSettings(context.Background(), &MemoryStreamConfig{ dest, err := tcpServer.StartContext(context.Background(), &SocketConfig{Tfo: SocketConfig_Enable})
SocketSettings: &SocketConfig{
Tfo: SocketConfig_Enable,
},
}))
common.Must(err) common.Must(err)
defer tcpServer.Close() defer tcpServer.Close()
ctx := context.Background() ctx := context.Background()
ctx = ContextWithStreamSettings(ctx, &MemoryStreamConfig{
SocketSettings: &SocketConfig{
Tfo: SocketConfig_Enable,
},
})
dialer := DefaultSystemDialer{} dialer := DefaultSystemDialer{}
conn, err := dialer.Dial(ctx, nil, dest) conn, err := dialer.Dial(ctx, nil, dest, &SocketConfig{
Tfo: SocketConfig_Enable,
})
common.Must(err) common.Must(err)
defer conn.Close() defer conn.Close()

View File

@ -2,8 +2,6 @@ package internet
import ( import (
"syscall" "syscall"
"v2ray.com/core/common/net"
) )
const ( const (
@ -45,6 +43,6 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig)
return nil return nil
} }
func bindAddr(fd uintptr, address net.Address, port net.Port) error { func bindAddr(fd uintptr, ip []byte, port uint32) error {
return nil return nil
} }

View File

@ -14,38 +14,27 @@ var (
) )
type SystemDialer interface { type SystemDialer interface {
Dial(ctx context.Context, source net.Address, destination net.Destination) (net.Conn, error) Dial(ctx context.Context, source net.Address, destination net.Destination, sockopt *SocketConfig) (net.Conn, error)
} }
type DefaultSystemDialer struct { type DefaultSystemDialer struct {
} }
func getSocketSettings(ctx context.Context) *SocketConfig { func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
streamSettings := StreamSettingsFromContext(ctx)
if streamSettings != nil && streamSettings.SocketSettings != nil {
return streamSettings.SocketSettings
}
return nil
}
func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination) (net.Conn, error) {
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: time.Second * 60, Timeout: time.Second * 60,
DualStack: true, DualStack: true,
} }
sockopts := getSocketSettings(ctx) if sockopt != nil {
if sockopts != nil {
bindAddress := BindAddressFromContext(ctx)
dialer.Control = func(network, address string, c syscall.RawConn) error { dialer.Control = func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
if err := applyOutboundSocketOptions(network, address, fd, sockopts); err != nil { if err := applyOutboundSocketOptions(network, address, fd, sockopt); err != nil {
newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx)) newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx))
} }
if dest.Network == net.Network_UDP && bindAddress.IsValid() { if dest.Network == net.Network_UDP && len(sockopt.BindAddress) > 0 && sockopt.BindPort > 0 {
if err := bindAddr(fd, bindAddress.Address, bindAddress.Port); err != nil { if err := bindAddr(fd, sockopt.BindAddress, sockopt.BindPort); err != nil {
newError("failed to bind source address to ", bindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx)) newError("failed to bind source address to ", sockopt.BindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx))
} }
} }
}) })
@ -84,7 +73,7 @@ func WithAdapter(dialer SystemDialerAdapter) SystemDialer {
} }
} }
func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination) (net.Conn, error) { func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr()) return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr())
} }

View File

@ -14,10 +14,9 @@ var (
type DefaultListener struct{} type DefaultListener struct{}
func (*DefaultListener) Listen(ctx context.Context, addr net.Addr) (net.Listener, error) { func (*DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) {
var lc net.ListenConfig var lc net.ListenConfig
sockopt := getSocketSettings(ctx)
if sockopt != nil { if sockopt != nil {
lc.Control = func(network, address string, c syscall.RawConn) error { lc.Control = func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
@ -31,10 +30,9 @@ func (*DefaultListener) Listen(ctx context.Context, addr net.Addr) (net.Listener
return lc.Listen(ctx, addr.Network(), addr.String()) return lc.Listen(ctx, addr.Network(), addr.String())
} }
func (*DefaultListener) ListenPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) { func (*DefaultListener) ListenPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) {
var lc net.ListenConfig var lc net.ListenConfig
sockopt := getSocketSettings(ctx)
if sockopt != nil { if sockopt != nil {
lc.Control = func(network, address string, c syscall.RawConn) error { lc.Control = func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {

View File

@ -10,28 +10,20 @@ import (
"v2ray.com/core/transport/internet/tls" "v2ray.com/core/transport/internet/tls"
) )
func getTCPSettingsFromContext(ctx context.Context) *Config {
rawTCPSettings := internet.StreamSettingsFromContext(ctx)
if rawTCPSettings == nil {
return nil
}
return rawTCPSettings.ProtocolSettings.(*Config)
}
// Dial dials a new TCP connection to the given destination. // Dial dials a new TCP connection to the given destination.
func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
newError("dialing TCP to ", dest).WriteToLog(session.ExportIDToError(ctx)) newError("dialing TCP to ", dest).WriteToLog(session.ExportIDToError(ctx))
conn, err := internet.DialSystem(ctx, dest) conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config := tls.ConfigFromContext(ctx); config != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
conn = tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2"))) conn = tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2")))
} }
tcpSettings := getTCPSettingsFromContext(ctx) tcpSettings := streamSettings.ProtocolSettings.(*Config)
if tcpSettings != nil && tcpSettings.HeaderSettings != nil { if tcpSettings.HeaderSettings != nil {
headerConfig, err := tcpSettings.HeaderSettings.GetInstance() headerConfig, err := tcpSettings.HeaderSettings.GetInstance()
if err != nil { if err != nil {
return nil, newError("failed to get header settings").Base(err).AtError() return nil, newError("failed to get header settings").Base(err).AtError()

View File

@ -22,25 +22,24 @@ type Listener struct {
} }
// ListenTCP creates a new Listener based on configurations. // ListenTCP creates a new Listener based on configurations.
func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
}) }, streamSettings.SocketSettings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx))
tcpSettings := getTCPSettingsFromContext(ctx) tcpSettings := streamSettings.ProtocolSettings.(*Config)
l := &Listener{ l := &Listener{
listener: listener, listener: listener,
config: tcpSettings, config: tcpSettings,
addConn: handler, addConn: handler,
} }
if config := tls.ConfigFromContext(ctx); config != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
l.tlsConfig = config.GetTLSConfig(tls.WithNextProto("h2")) l.tlsConfig = config.GetTLSConfig(tls.WithNextProto("h2"))
} }

View File

@ -20,22 +20,20 @@ func RegisterTransportListener(protocol string, listener ListenFunc) error {
type ConnHandler func(Connection) type ConnHandler func(Connection)
type ListenFunc func(ctx context.Context, address net.Address, port net.Port, handler ConnHandler) (Listener, error) type ListenFunc func(ctx context.Context, address net.Address, port net.Port, settings *MemoryStreamConfig, handler ConnHandler) (Listener, error)
type Listener interface { type Listener interface {
Close() error Close() error
Addr() net.Addr Addr() net.Addr
} }
func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler ConnHandler) (Listener, error) { func ListenTCP(ctx context.Context, address net.Address, port net.Port, settings *MemoryStreamConfig, handler ConnHandler) (Listener, error) {
settings := StreamSettingsFromContext(ctx)
if settings == nil { if settings == nil {
s, err := ToMemoryStreamConfig(nil) s, err := ToMemoryStreamConfig(nil)
if err != nil { if err != nil {
return nil, newError("failed to create default stream settings").Base(err) return nil, newError("failed to create default stream settings").Base(err)
} }
settings = s settings = s
ctx = ContextWithStreamSettings(ctx, settings)
} }
if address.Family().IsDomain() && address.Domain() == "localhost" { if address.Family().IsDomain() && address.Domain() == "localhost" {
@ -47,17 +45,17 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler
if listenFunc == nil { if listenFunc == nil {
return nil, newError(protocol, " listener not registered.").AtError() return nil, newError(protocol, " listener not registered.").AtError()
} }
listener, err := listenFunc(ctx, address, port, handler) listener, err := listenFunc(ctx, address, port, settings, handler)
if err != nil { if err != nil {
return nil, newError("failed to listen on address: ", address, ":", port).Base(err) return nil, newError("failed to listen on address: ", address, ":", port).Base(err)
} }
return listener, nil return listener, nil
} }
func ListenSystem(ctx context.Context, addr net.Addr) (net.Listener, error) { func ListenSystem(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) {
return effectiveListener.Listen(ctx, addr) return effectiveListener.Listen(ctx, addr, sockopt)
} }
func ListenSystemPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) { func ListenSystemPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) {
return effectiveListener.ListenPacket(ctx, addr) return effectiveListener.ListenPacket(ctx, addr, sockopt)
} }

View File

@ -1,7 +1,6 @@
package tls package tls
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"sync" "sync"
@ -215,13 +214,12 @@ func WithNextProto(protocol ...string) Option {
} }
} }
// ConfigFromContext fetches Config from context. Nil if not found. // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found.
func ConfigFromContext(ctx context.Context) *Config { func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config {
streamSettings := internet.StreamSettingsFromContext(ctx) if settings == nil {
if streamSettings == nil {
return nil return nil
} }
config, ok := streamSettings.SecuritySettings.(*Config) config, ok := settings.SecuritySettings.(*Config)
if !ok { if !ok {
return nil return nil
} }

View File

@ -10,8 +10,12 @@ import (
func init() { func init() {
common.Must(internet.RegisterTransportDialer(protocolName, common.Must(internet.RegisterTransportDialer(protocolName,
func(ctx context.Context, dest net.Destination) (internet.Connection, error) { func(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
conn, err := internet.DialSystem(ctx, dest) var sockopt *internet.SocketConfig
if streamSettings != nil {
sockopt = streamSettings.SocketSettings
}
conn, err := internet.DialSystem(ctx, dest, sockopt)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -36,7 +36,7 @@ type Hub struct {
recvOrigDest bool recvOrigDest bool
} }
func ListenUDP(ctx context.Context, address net.Address, port net.Port, options ...HubOption) (*Hub, error) { func ListenUDP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, options ...HubOption) (*Hub, error) {
hub := &Hub{ hub := &Hub{
capacity: 256, capacity: 256,
recvOrigDest: false, recvOrigDest: false,
@ -45,15 +45,18 @@ func ListenUDP(ctx context.Context, address net.Address, port net.Port, options
opt(hub) opt(hub)
} }
streamSettings := internet.StreamSettingsFromContext(ctx) var sockopt *internet.SocketConfig
if streamSettings != nil && streamSettings.SocketSettings != nil && streamSettings.SocketSettings.ReceiveOriginalDestAddress { if streamSettings != nil {
sockopt = streamSettings.SocketSettings
}
if sockopt != nil && sockopt.ReceiveOriginalDestAddress {
hub.recvOrigDest = true hub.recvOrigDest = true
} }
udpConn, err := internet.ListenSystemPacket(ctx, &net.UDPAddr{ udpConn, err := internet.ListenSystemPacket(ctx, &net.UDPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
}) }, sockopt)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -14,10 +14,10 @@ import (
) )
// Dial dials a WebSocket connection to the given destination. // Dial dials a WebSocket connection to the given destination.
func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
conn, err := dialWebsocket(ctx, dest) conn, err := dialWebsocket(ctx, dest, streamSettings)
if err != nil { if err != nil {
return nil, newError("failed to dial WebSocket").Base(err) return nil, newError("failed to dial WebSocket").Base(err)
} }
@ -28,12 +28,12 @@ func init() {
common.Must(internet.RegisterTransportDialer(protocolName, Dial)) common.Must(internet.RegisterTransportDialer(protocolName, Dial))
} }
func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) { func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
wsSettings := internet.StreamSettingsFromContext(ctx).ProtocolSettings.(*Config) wsSettings := streamSettings.ProtocolSettings.(*Config)
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) { NetDial: func(network, addr string) (net.Conn, error) {
return internet.DialSystem(ctx, dest) return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
}, },
ReadBufferSize: 4 * 1024, ReadBufferSize: 4 * 1024,
WriteBufferSize: 4 * 1024, WriteBufferSize: 4 * 1024,
@ -42,7 +42,7 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error)
protocol := "ws" protocol := "ws"
if config := tls.ConfigFromContext(ctx); config != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
protocol = "wss" protocol = "wss"
dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest)) dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest))
} }

View File

@ -55,16 +55,15 @@ type Listener struct {
addConn internet.ConnHandler addConn internet.ConnHandler
} }
func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (internet.Listener, error) { func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
networkSettings := internet.StreamSettingsFromContext(ctx) wsSettings := streamSettings.ProtocolSettings.(*Config)
wsSettings := networkSettings.ProtocolSettings.(*Config)
var tlsConfig *tls.Config var tlsConfig *tls.Config
if config := v2tls.ConfigFromContext(ctx); config != nil { if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
tlsConfig = config.GetTLSConfig() tlsConfig = config.GetTLSConfig()
} }
listener, err := listenTCP(ctx, address, port, tlsConfig) listener, err := listenTCP(ctx, address, port, tlsConfig, streamSettings.SocketSettings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -84,11 +83,11 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn i
return l, err return l, err
} }
func listenTCP(ctx context.Context, address net.Address, port net.Port, tlsConfig *tls.Config) (net.Listener, error) { func listenTCP(ctx context.Context, address net.Address, port net.Port, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (net.Listener, error) {
listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
}) }, sockopt)
if err != nil { if err != nil {
return nil, newError("failed to listen TCP on", address, ":", port).Base(err) return nil, newError("failed to listen TCP on", address, ":", port).Base(err)
} }

View File

@ -18,13 +18,12 @@ import (
func Test_listenWSAndDial(t *testing.T) { func Test_listenWSAndDial(t *testing.T) {
assert := With(t) assert := With(t)
lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ listen, err := ListenWS(context.Background(), net.LocalHostIP, 13146, &internet.MemoryStreamConfig{
ProtocolName: "websocket", ProtocolName: "websocket",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "ws", Path: "ws",
}, },
}) }, func(conn internet.Connection) {
listen, err := ListenWS(lctx, net.LocalHostIP, 13146, func(conn internet.Connection) {
go func(c internet.Connection) { go func(c internet.Connection) {
defer c.Close() defer c.Close()
@ -42,11 +41,12 @@ func Test_listenWSAndDial(t *testing.T) {
}) })
assert(err, IsNil) assert(err, IsNil)
ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ ctx := context.Background()
streamSettings := &internet.MemoryStreamConfig{
ProtocolName: "websocket", ProtocolName: "websocket",
ProtocolSettings: &Config{Path: "ws"}, ProtocolSettings: &Config{Path: "ws"},
}) }
conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
assert(err, IsNil) assert(err, IsNil)
_, err = conn.Write([]byte("Test connection 1")) _, err = conn.Write([]byte("Test connection 1"))
@ -59,7 +59,7 @@ func Test_listenWSAndDial(t *testing.T) {
assert(conn.Close(), IsNil) assert(conn.Close(), IsNil)
<-time.After(time.Second * 5) <-time.After(time.Second * 5)
conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
assert(err, IsNil) assert(err, IsNil)
_, err = conn.Write([]byte("Test connection 2")) _, err = conn.Write([]byte("Test connection 2"))
assert(err, IsNil) assert(err, IsNil)
@ -73,13 +73,12 @@ func Test_listenWSAndDial(t *testing.T) {
func TestDialWithRemoteAddr(t *testing.T) { func TestDialWithRemoteAddr(t *testing.T) {
assert := With(t) assert := With(t)
lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ listen, err := ListenWS(context.Background(), net.LocalHostIP, 13148, &internet.MemoryStreamConfig{
ProtocolName: "websocket", ProtocolName: "websocket",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "ws", Path: "ws",
}, },
}) }, func(conn internet.Connection) {
listen, err := ListenWS(lctx, net.LocalHostIP, 13148, func(conn internet.Connection) {
go func(c internet.Connection) { go func(c internet.Connection) {
defer c.Close() defer c.Close()
@ -99,11 +98,10 @@ func TestDialWithRemoteAddr(t *testing.T) {
}) })
assert(err, IsNil) assert(err, IsNil)
ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13148), &internet.MemoryStreamConfig{
ProtocolName: "websocket", ProtocolName: "websocket",
ProtocolSettings: &Config{Path: "ws", Header: []*Header{{Key: "X-Forwarded-For", Value: "1.1.1.1"}}}, ProtocolSettings: &Config{Path: "ws", Header: []*Header{{Key: "X-Forwarded-For", Value: "1.1.1.1"}}},
}) })
conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13148))
assert(err, IsNil) assert(err, IsNil)
_, err = conn.Write([]byte("Test connection 1")) _, err = conn.Write([]byte("Test connection 1"))
@ -126,7 +124,7 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
start := time.Now() start := time.Now()
ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ streamSettings := &internet.MemoryStreamConfig{
ProtocolName: "websocket", ProtocolName: "websocket",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
Path: "wss", Path: "wss",
@ -136,9 +134,8 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
AllowInsecure: true, AllowInsecure: true,
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
}, },
}) }
listen, err := ListenWS(context.Background(), net.LocalHostIP, 13143, streamSettings, func(conn internet.Connection) {
listen, err := ListenWS(ctx, net.LocalHostIP, 13143, func(conn internet.Connection) {
go func() { go func() {
_ = conn.Close() _ = conn.Close()
}() }()
@ -146,7 +143,7 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
assert(err, IsNil) assert(err, IsNil)
defer listen.Close() defer listen.Close()
conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13143)) conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13143), streamSettings)
assert(err, IsNil) assert(err, IsNil)
_ = conn.Close() _ = conn.Close()