package remotedialer import ( "net/http" "sync" "time" "github.com/gorilla/websocket" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) var ( errFailedAuth = errors.New("failed authentication") errWrongMessageType = errors.New("wrong websocket message type") ) type Authorizer func(req *http.Request) (clientKey string, authed bool, err error) type ErrorWriter func(rw http.ResponseWriter, req *http.Request, code int, err error) func DefaultErrorWriter(rw http.ResponseWriter, req *http.Request, code int, err error) { rw.Write([]byte(err.Error())) rw.WriteHeader(code) } type Server struct { PeerID string PeerToken string authorizer Authorizer errorWriter ErrorWriter sessions *sessionManager peers map[string]peer peerLock sync.Mutex } func New(auth Authorizer, errorWriter ErrorWriter) *Server { return &Server{ peers: map[string]peer{}, authorizer: auth, errorWriter: errorWriter, sessions: newSessionManager(), } } func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) { clientKey, authed, peer, err := s.auth(req) if err != nil { s.errorWriter(rw, req, 400, err) return } if !authed { s.errorWriter(rw, req, 401, errFailedAuth) return } logrus.Infof("Handling backend connection request [%s]", clientKey) upgrader := websocket.Upgrader{ HandshakeTimeout: 5 * time.Second, CheckOrigin: func(r *http.Request) bool { return true }, Error: s.errorWriter, } wsConn, err := upgrader.Upgrade(rw, req, nil) if err != nil { s.errorWriter(rw, req, 400, errors.Wrapf(err, "Error during upgrade for host [%v]", clientKey)) return } session := s.sessions.add(clientKey, wsConn, peer) defer s.sessions.remove(session) // Don't need to associate req.Context() to the Session, it will cancel otherwise code, err := session.Serve() if err != nil { // Hijacked so we can't write to the client logrus.Infof("error in remotedialer server [%d]: %v", code, err) } } func (s *Server) auth(req *http.Request) (clientKey string, authed, peer bool, err error) { id := req.Header.Get(ID) token := req.Header.Get(Token) if id != "" && token != "" { // peer authentication s.peerLock.Lock() p, ok := s.peers[id] s.peerLock.Unlock() if ok && p.token == token { return id, true, true, nil } } id, authed, err = s.authorizer(req) return id, authed, false, err }