mirror of https://github.com/XTLS/Xray-core
				
				
				
			Fix: gRPC & HTTP/2 dialer (#445)
							parent
							
								
									b63049f404
								
							
						
					
					
						commit
						3ed14c2fcd
					
				| 
						 | 
				
			
			@ -36,6 +36,7 @@ func init() {
 | 
			
		|||
type dialerConf struct {
 | 
			
		||||
	net.Destination
 | 
			
		||||
	*internet.SocketConfig
 | 
			
		||||
	*tls.Config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
| 
						 | 
				
			
			@ -46,14 +47,9 @@ var (
 | 
			
		|||
func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
 | 
			
		||||
	grpcSettings := streamSettings.ProtocolSettings.(*Config)
 | 
			
		||||
 | 
			
		||||
	config := tls.ConfigFromStreamSettings(streamSettings)
 | 
			
		||||
	var dialOption = grpc.WithInsecure()
 | 
			
		||||
	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
 | 
			
		||||
 | 
			
		||||
	if config != nil {
 | 
			
		||||
		dialOption = grpc.WithTransportCredentials(credentials.NewTLS(config.GetTLSConfig()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn, err := getGrpcClient(ctx, dest, dialOption, streamSettings.SocketSettings)
 | 
			
		||||
	conn, err := getGrpcClient(ctx, dest, tlsConfig, streamSettings.SocketSettings)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, newError("Cannot dial gRPC").Base(err)
 | 
			
		||||
| 
						 | 
				
			
			@ -76,7 +72,7 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
 | 
			
		|||
	return encoding.NewHunkConn(grpcService, nil), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) {
 | 
			
		||||
func getGrpcClient(ctx context.Context, dest net.Destination, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) {
 | 
			
		||||
	globalDialerAccess.Lock()
 | 
			
		||||
	defer globalDialerAccess.Unlock()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -84,10 +80,16 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
 | 
			
		|||
		globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if client, found := globalDialerMap[dialerConf{dest, sockopt}]; found && client.GetState() != connectivity.Shutdown {
 | 
			
		||||
	if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsConfig}]; found && client.GetState() != connectivity.Shutdown {
 | 
			
		||||
		return client, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dialOption := grpc.WithInsecure()
 | 
			
		||||
 | 
			
		||||
	if tlsConfig != nil {
 | 
			
		||||
		dialOption = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig.GetTLSConfig()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn, err := grpc.Dial(
 | 
			
		||||
		gonet.JoinHostPort(dest.Address.String(), dest.Port.String()),
 | 
			
		||||
		dialOption,
 | 
			
		||||
| 
						 | 
				
			
			@ -125,6 +127,6 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
 | 
			
		|||
			return internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
 | 
			
		||||
		}),
 | 
			
		||||
	)
 | 
			
		||||
	globalDialerMap[dialerConf{dest, sockopt}] = conn
 | 
			
		||||
	globalDialerMap[dialerConf{dest, sockopt, tlsConfig}] = conn
 | 
			
		||||
	return conn, err
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,6 +21,7 @@ import (
 | 
			
		|||
type dialerConf struct {
 | 
			
		||||
	net.Destination
 | 
			
		||||
	*internet.SocketConfig
 | 
			
		||||
	*tls.Config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
| 
						 | 
				
			
			@ -36,7 +37,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C
 | 
			
		|||
		globalDialerMap = make(map[dialerConf]*http.Client)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if client, found := globalDialerMap[dialerConf{dest, sockopt}]; found {
 | 
			
		||||
	if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsSettings}]; found {
 | 
			
		||||
		return client, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -92,7 +93,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C
 | 
			
		|||
		Transport: transport,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	globalDialerMap[dialerConf{dest, sockopt}] = client
 | 
			
		||||
	globalDialerMap[dialerConf{dest, sockopt, tlsSettings}] = client
 | 
			
		||||
	return client, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue