diff --git a/src/utils/vhost/vhost.go b/src/utils/vhost/vhost.go index 93279b3..12c8164 100644 --- a/src/utils/vhost/vhost.go +++ b/src/utils/vhost/vhost.go @@ -71,6 +71,18 @@ func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err e func (v *VhostMuxer) getListener(name string) (l *Listener, exist bool) { v.mutex.RLock() defer v.mutex.RUnlock() + // first we check the full hostname + // if not exist, then check the wildcard_domain such as *.example.com + l, exist = v.registryMap[name] + if exist { + return l, exist + } + domainSplit := strings.Split(name, ".") + if len(domainSplit) < 3 { + return l, false + } + domainSplit[0] = "*" + name = strings.Join(domainSplit, ".") l, exist = v.registryMap[name] return l, exist } @@ -93,21 +105,26 @@ func (v *VhostMuxer) run() { func (v *VhostMuxer) handle(c *conn.Conn) { if err := c.SetDeadline(time.Now().Add(v.timeout)); err != nil { + c.Close() return } sConn, name, err := v.vhostFunc(c) if err != nil { + c.Close() return } name = strings.ToLower(name) + // get listener by hostname l, ok := v.getListener(name) if !ok { + c.Close() return } if err = sConn.SetDeadline(time.Time{}); err != nil { + c.Close() return } c.SetTcpConn(sConn)