From d270a99d4f90a49b0eabbce5ca41bae7e754d6e5 Mon Sep 17 00:00:00 2001
From: Darien Raymond <admin@v2ray.com>
Date: Thu, 9 Feb 2017 12:31:40 +0100
Subject: [PATCH] simplify code

---
 transport/internet/websocket/hub.go | 97 ++++++++++++++++-------------
 1 file changed, 52 insertions(+), 45 deletions(-)

diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go
index 0619ff1e..1487bdba 100644
--- a/transport/internet/websocket/hub.go
+++ b/transport/internet/websocket/hub.go
@@ -27,7 +27,30 @@ type ConnectionWithError struct {
 	err  error
 }
 
-type WSListener struct {
+type requestHandler struct {
+	path  string
+	conns chan *ConnectionWithError
+}
+
+func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
+	if request.URL.Path != h.path {
+		writer.WriteHeader(http.StatusNotFound)
+		return
+	}
+	conn, err := converttovws(writer, request)
+	if err != nil {
+		log.Info("WebSocket|Listener: Failed to convert to WebSocket connection: ", err)
+		return
+	}
+
+	select {
+	case h.conns <- &ConnectionWithError{conn: conn}:
+	default:
+		conn.Close()
+	}
+}
+
+type Listener struct {
 	sync.Mutex
 	acccepting    bool
 	awaitingConns chan *ConnectionWithError
@@ -43,7 +66,7 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt
 	}
 	wsSettings := networkSettings.(*Config)
 
-	l := &WSListener{
+	l := &Listener{
 		acccepting:    true,
 		awaitingConns: make(chan *ConnectionWithError, 32),
 		config:        wsSettings,
@@ -65,51 +88,35 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt
 	return l, err
 }
 
-func (wsl *WSListener) listenws(address v2net.Address, port v2net.Port) error {
-	http.HandleFunc("/"+wsl.config.Path, func(w http.ResponseWriter, r *http.Request) {
-		conn, err := wsl.converttovws(w, r)
-		if err != nil {
-			log.Warning("WebSocket|Listener: Failed to convert connection: ", err)
-			return
-		}
-
-		select {
-		case wsl.awaitingConns <- &ConnectionWithError{
-			conn: conn,
-		}:
-		default:
-			if conn != nil {
-				conn.Close()
-			}
-		}
-		return
-	})
-
+func (ln *Listener) listenws(address v2net.Address, port v2net.Port) error {
 	netAddr := address.String() + ":" + strconv.Itoa(int(port.Value()))
 	var listener net.Listener
-	if wsl.tlsConfig == nil {
+	if ln.tlsConfig == nil {
 		l, err := net.Listen("tcp", netAddr)
 		if err != nil {
 			return errors.Base(err).Message("WebSocket|Listener: Failed to listen TCP ", netAddr)
 		}
 		listener = l
 	} else {
-		l, err := tls.Listen("tcp", netAddr, wsl.tlsConfig)
+		l, err := tls.Listen("tcp", netAddr, ln.tlsConfig)
 		if err != nil {
 			return errors.Base(err).Message("WebSocket|Listener: Failed to listen TLS ", netAddr)
 		}
 		listener = l
 	}
-	wsl.listener = listener
+	ln.listener = listener
 
 	go func() {
-		http.Serve(listener, nil)
+		http.Serve(listener, &requestHandler{
+			path:  "/" + ln.config.Path,
+			conns: ln.awaitingConns,
+		})
 	}()
 
 	return nil
 }
 
-func (wsl *WSListener) converttovws(w http.ResponseWriter, r *http.Request) (*wsconn, error) {
+func converttovws(w http.ResponseWriter, r *http.Request) (*wsconn, error) {
 	var upgrader = websocket.Upgrader{
 		ReadBufferSize:  32 * 1024,
 		WriteBufferSize: 32 * 1024,
@@ -123,49 +130,49 @@ func (wsl *WSListener) converttovws(w http.ResponseWriter, r *http.Request) (*ws
 	return &wsconn{wsc: conn}, nil
 }
 
-func (v *WSListener) Accept() (internet.Connection, error) {
-	for v.acccepting {
+func (ln *Listener) Accept() (internet.Connection, error) {
+	for ln.acccepting {
 		select {
-		case connErr, open := <-v.awaitingConns:
+		case connErr, open := <-ln.awaitingConns:
 			if !open {
 				return nil, ErrClosedListener
 			}
 			if connErr.err != nil {
 				return nil, connErr.err
 			}
-			return internal.NewConnection(internal.ConnectionID{}, connErr.conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())), nil
+			return internal.NewConnection(internal.ConnectionID{}, connErr.conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())), nil
 		case <-time.After(time.Second * 2):
 		}
 	}
 	return nil, ErrClosedListener
 }
 
-func (v *WSListener) Put(id internal.ConnectionID, conn net.Conn) {
-	v.Lock()
-	defer v.Unlock()
-	if !v.acccepting {
+func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
+	ln.Lock()
+	defer ln.Unlock()
+	if !ln.acccepting {
 		return
 	}
 	select {
-	case v.awaitingConns <- &ConnectionWithError{conn: conn}:
+	case ln.awaitingConns <- &ConnectionWithError{conn: conn}:
 	default:
 		conn.Close()
 	}
 }
 
-func (v *WSListener) Addr() net.Addr {
-	return nil
+func (ln *Listener) Addr() net.Addr {
+	return ln.listener.Addr()
 }
 
-func (v *WSListener) Close() error {
-	v.Lock()
-	defer v.Unlock()
-	v.acccepting = false
+func (ln *Listener) Close() error {
+	ln.Lock()
+	defer ln.Unlock()
+	ln.acccepting = false
 
-	v.listener.Close()
+	ln.listener.Close()
 
-	close(v.awaitingConns)
-	for connErr := range v.awaitingConns {
+	close(ln.awaitingConns)
+	for connErr := range ln.awaitingConns {
 		if connErr.conn != nil {
 			connErr.conn.Close()
 		}