mirror of https://github.com/fatedier/frp
sshTunnelGateway refactor (#3784)
parent
8b432e179d
commit
d5b41f1e14
4
Makefile
4
Makefile
|
@ -26,10 +26,10 @@ vet:
|
||||||
go vet ./...
|
go vet ./...
|
||||||
|
|
||||||
frps:
|
frps:
|
||||||
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -o bin/frps ./cmd/frps
|
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frps -o bin/frps ./cmd/frps
|
||||||
|
|
||||||
frpc:
|
frpc:
|
||||||
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -o bin/frpc ./cmd/frpc
|
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frpc -o bin/frpc ./cmd/frpc
|
||||||
|
|
||||||
test: gotest
|
test: gotest
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,223 @@
|
||||||
|
// Copyright 2023 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
libdial "github.com/fatedier/golib/net/dial"
|
||||||
|
fmux "github.com/hashicorp/yamux"
|
||||||
|
quic "github.com/quic-go/quic-go"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
|
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connector is a interface for establishing connections to the server.
|
||||||
|
type Connector interface {
|
||||||
|
Open() error
|
||||||
|
Connect() (net.Conn, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
|
||||||
|
type defaultConnectorImpl struct {
|
||||||
|
ctx context.Context
|
||||||
|
cfg *v1.ClientCommonConfig
|
||||||
|
|
||||||
|
muxSession *fmux.Session
|
||||||
|
quicConn quic.Connection
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnector(ctx context.Context, cfg *v1.ClientCommonConfig) Connector {
|
||||||
|
return &defaultConnectorImpl{
|
||||||
|
ctx: ctx,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open opens a underlying connection to the server.
|
||||||
|
// The underlying connection is either a TCP connection or a QUIC connection.
|
||||||
|
// After the underlying connection is established, you can call Connect() to get a stream.
|
||||||
|
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
|
||||||
|
func (c *defaultConnectorImpl) Open() error {
|
||||||
|
xl := xlog.FromContextSafe(c.ctx)
|
||||||
|
|
||||||
|
// special for quic
|
||||||
|
if strings.EqualFold(c.cfg.Transport.Protocol, "quic") {
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
var err error
|
||||||
|
sn := c.cfg.Transport.TLS.ServerName
|
||||||
|
if sn == "" {
|
||||||
|
sn = c.cfg.ServerAddr
|
||||||
|
}
|
||||||
|
if lo.FromPtr(c.cfg.Transport.TLS.Enable) {
|
||||||
|
tlsConfig, err = transport.NewClientTLSConfig(
|
||||||
|
c.cfg.Transport.TLS.CertFile,
|
||||||
|
c.cfg.Transport.TLS.KeyFile,
|
||||||
|
c.cfg.Transport.TLS.TrustedCaFile,
|
||||||
|
sn)
|
||||||
|
} else {
|
||||||
|
tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
xl.Warn("fail to build tls configuration, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tlsConfig.NextProtos = []string{"frp"}
|
||||||
|
|
||||||
|
conn, err := quic.DialAddr(
|
||||||
|
c.ctx,
|
||||||
|
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)),
|
||||||
|
tlsConfig, &quic.Config{
|
||||||
|
MaxIdleTimeout: time.Duration(c.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
|
||||||
|
MaxIncomingStreams: int64(c.cfg.Transport.QUIC.MaxIncomingStreams),
|
||||||
|
KeepAlivePeriod: time.Duration(c.cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.quicConn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !lo.FromPtr(c.cfg.Transport.TCPMux) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := c.realConnect()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmuxCfg := fmux.DefaultConfig()
|
||||||
|
fmuxCfg.KeepAliveInterval = time.Duration(c.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
||||||
|
fmuxCfg.LogOutput = io.Discard
|
||||||
|
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
|
||||||
|
session, err := fmux.Client(conn, fmuxCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.muxSession = session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
|
||||||
|
func (c *defaultConnectorImpl) Connect() (net.Conn, error) {
|
||||||
|
if c.quicConn != nil {
|
||||||
|
stream, err := c.quicConn.OpenStreamSync(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return utilnet.QuicStreamToNetConn(stream, c.quicConn), nil
|
||||||
|
} else if c.muxSession != nil {
|
||||||
|
stream, err := c.muxSession.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.realConnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *defaultConnectorImpl) realConnect() (net.Conn, error) {
|
||||||
|
xl := xlog.FromContextSafe(c.ctx)
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
var err error
|
||||||
|
tlsEnable := lo.FromPtr(c.cfg.Transport.TLS.Enable)
|
||||||
|
if c.cfg.Transport.Protocol == "wss" {
|
||||||
|
tlsEnable = true
|
||||||
|
}
|
||||||
|
if tlsEnable {
|
||||||
|
sn := c.cfg.Transport.TLS.ServerName
|
||||||
|
if sn == "" {
|
||||||
|
sn = c.cfg.ServerAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err = transport.NewClientTLSConfig(
|
||||||
|
c.cfg.Transport.TLS.CertFile,
|
||||||
|
c.cfg.Transport.TLS.KeyFile,
|
||||||
|
c.cfg.Transport.TLS.TrustedCaFile,
|
||||||
|
sn)
|
||||||
|
if err != nil {
|
||||||
|
xl.Warn("fail to build tls configuration, err: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyType, addr, auth, err := libdial.ParseProxyURL(c.cfg.Transport.ProxyURL)
|
||||||
|
if err != nil {
|
||||||
|
xl.Error("fail to parse proxy url")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dialOptions := []libdial.DialOption{}
|
||||||
|
protocol := c.cfg.Transport.Protocol
|
||||||
|
switch protocol {
|
||||||
|
case "websocket":
|
||||||
|
protocol = "tcp"
|
||||||
|
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
|
||||||
|
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
||||||
|
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
||||||
|
}))
|
||||||
|
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
||||||
|
case "wss":
|
||||||
|
protocol = "tcp"
|
||||||
|
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
|
||||||
|
// Make sure that if it is wss, the websocket hook is executed after the tls hook.
|
||||||
|
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
|
||||||
|
default:
|
||||||
|
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
||||||
|
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
||||||
|
}))
|
||||||
|
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.cfg.Transport.ConnectServerLocalIP != "" {
|
||||||
|
dialOptions = append(dialOptions, libdial.WithLocalAddr(c.cfg.Transport.ConnectServerLocalIP))
|
||||||
|
}
|
||||||
|
dialOptions = append(dialOptions,
|
||||||
|
libdial.WithProtocol(protocol),
|
||||||
|
libdial.WithTimeout(time.Duration(c.cfg.Transport.DialServerTimeout)*time.Second),
|
||||||
|
libdial.WithKeepAlive(time.Duration(c.cfg.Transport.DialServerKeepAlive)*time.Second),
|
||||||
|
libdial.WithProxy(proxyType, addr),
|
||||||
|
libdial.WithProxyAuth(auth),
|
||||||
|
)
|
||||||
|
conn, err := libdial.DialContext(
|
||||||
|
c.ctx,
|
||||||
|
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)),
|
||||||
|
dialOptions...,
|
||||||
|
)
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *defaultConnectorImpl) Close() error {
|
||||||
|
if c.quicConn != nil {
|
||||||
|
_ = c.quicConn.CloseWithError(0, "")
|
||||||
|
}
|
||||||
|
if c.muxSession != nil {
|
||||||
|
_ = c.muxSession.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -58,8 +58,8 @@ type Control struct {
|
||||||
// control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
|
// control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
|
||||||
// use cm to create new connections, which could be real TCP connections or virtual streams.
|
// use connector to create new connections, which could be real TCP connections or virtual streams.
|
||||||
cm *ConnectionManager
|
connector Connector
|
||||||
|
|
||||||
doneCh chan struct{}
|
doneCh chan struct{}
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ type Control struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewControl(
|
func NewControl(
|
||||||
ctx context.Context, runID string, conn net.Conn, cm *ConnectionManager,
|
ctx context.Context, runID string, conn net.Conn, connector Connector,
|
||||||
clientCfg *v1.ClientCommonConfig,
|
clientCfg *v1.ClientCommonConfig,
|
||||||
pxyCfgs []v1.ProxyConfigurer,
|
pxyCfgs []v1.ProxyConfigurer,
|
||||||
visitorCfgs []v1.VisitorConfigurer,
|
visitorCfgs []v1.VisitorConfigurer,
|
||||||
|
@ -92,7 +92,7 @@ func NewControl(
|
||||||
runID: runID,
|
runID: runID,
|
||||||
pxyCfgs: pxyCfgs,
|
pxyCfgs: pxyCfgs,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
cm: cm,
|
connector: connector,
|
||||||
doneCh: make(chan struct{}),
|
doneCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
ctl.lastPong.Store(time.Now())
|
ctl.lastPong.Store(time.Now())
|
||||||
|
@ -122,6 +122,10 @@ func (ctl *Control) Run() {
|
||||||
go ctl.vm.Run()
|
go ctl.vm.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ctl *Control) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||||
|
ctl.pm.SetInWorkConnCallback(cb)
|
||||||
|
}
|
||||||
|
|
||||||
func (ctl *Control) handleReqWorkConn(_ msg.Message) {
|
func (ctl *Control) handleReqWorkConn(_ msg.Message) {
|
||||||
xl := ctl.xl
|
xl := ctl.xl
|
||||||
workConn, err := ctl.connectServer()
|
workConn, err := ctl.connectServer()
|
||||||
|
@ -207,7 +211,7 @@ func (ctl *Control) GracefulClose(d time.Duration) error {
|
||||||
time.Sleep(d)
|
time.Sleep(d)
|
||||||
|
|
||||||
ctl.conn.Close()
|
ctl.conn.Close()
|
||||||
ctl.cm.Close()
|
ctl.connector.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,7 +222,7 @@ func (ctl *Control) Done() <-chan struct{} {
|
||||||
|
|
||||||
// connectServer return a new connection to frps
|
// connectServer return a new connection to frps
|
||||||
func (ctl *Control) connectServer() (conn net.Conn, err error) {
|
func (ctl *Control) connectServer() (conn net.Conn, err error) {
|
||||||
return ctl.cm.Connect()
|
return ctl.connector.Connect()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctl *Control) registerMsgHandlers() {
|
func (ctl *Control) registerMsgHandlers() {
|
||||||
|
@ -282,7 +286,7 @@ func (ctl *Control) worker() {
|
||||||
|
|
||||||
ctl.pm.Close()
|
ctl.pm.Close()
|
||||||
ctl.vm.Close()
|
ctl.vm.Close()
|
||||||
ctl.cm.Close()
|
ctl.connector.Close()
|
||||||
|
|
||||||
close(ctl.doneCh)
|
close(ctl.doneCh)
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,10 +47,9 @@ func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy, v
|
||||||
// Proxy defines how to handle work connections for different proxy type.
|
// Proxy defines how to handle work connections for different proxy type.
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
Run() error
|
Run() error
|
||||||
|
|
||||||
// InWorkConn accept work connections registered to server.
|
// InWorkConn accept work connections registered to server.
|
||||||
InWorkConn(net.Conn, *msg.StartWorkConn)
|
InWorkConn(net.Conn, *msg.StartWorkConn)
|
||||||
|
SetInWorkConnCallback(func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) /* continue */ bool)
|
||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +88,8 @@ type BaseProxy struct {
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
// proxyPlugin is used to handle connections instead of dialing to local service.
|
// proxyPlugin is used to handle connections instead of dialing to local service.
|
||||||
// It's only validate for TCP protocol now.
|
// It's only validate for TCP protocol now.
|
||||||
proxyPlugin plugin.Plugin
|
proxyPlugin plugin.Plugin
|
||||||
|
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) /* continue */ bool
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
xl *xlog.Logger
|
xl *xlog.Logger
|
||||||
|
@ -113,7 +113,16 @@ func (pxy *BaseProxy) Close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||||
|
pxy.inWorkConnCallback = cb
|
||||||
|
}
|
||||||
|
|
||||||
func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
|
func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
|
||||||
|
if pxy.inWorkConnCallback != nil {
|
||||||
|
if !pxy.inWorkConnCallback(pxy.baseCfg, conn, m) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Auth.Token))
|
pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Auth.Token))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +141,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t",
|
xl.Trace("handle tcp work connection, useEncryption: %t, useCompression: %t",
|
||||||
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
|
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
|
||||||
if baseCfg.Transport.UseEncryption {
|
if baseCfg.Transport.UseEncryption {
|
||||||
remote, err = libio.WithEncryption(remote, encKey)
|
remote, err = libio.WithEncryption(remote, encKey)
|
||||||
|
|
|
@ -31,8 +31,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
proxies map[string]*Wrapper
|
proxies map[string]*Wrapper
|
||||||
msgTransporter transport.MessageTransporter
|
msgTransporter transport.MessageTransporter
|
||||||
|
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
@ -71,6 +72,10 @@ func (pm *Manager) StartProxy(name string, remoteAddr string, serverRespErr stri
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *Manager) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||||
|
pm.inWorkConnCallback = cb
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *Manager) Close() {
|
func (pm *Manager) Close() {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
@ -146,6 +151,9 @@ func (pm *Manager) Reload(pxyCfgs []v1.ProxyConfigurer) {
|
||||||
name := cfg.GetBaseConfig().Name
|
name := cfg.GetBaseConfig().Name
|
||||||
if _, ok := pm.proxies[name]; !ok {
|
if _, ok := pm.proxies[name]; !ok {
|
||||||
pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter)
|
pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter)
|
||||||
|
if pm.inWorkConnCallback != nil {
|
||||||
|
pxy.SetInWorkConnCallback(pm.inWorkConnCallback)
|
||||||
|
}
|
||||||
pm.proxies[name] = pxy
|
pm.proxies[name] = pxy
|
||||||
addPxyNames = append(addPxyNames, name)
|
addPxyNames = append(addPxyNames, name)
|
||||||
|
|
||||||
|
|
|
@ -121,6 +121,10 @@ func NewWrapper(
|
||||||
return pw
|
return pw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pw *Wrapper) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||||
|
pw.pxy.SetInWorkConnCallback(cb)
|
||||||
|
}
|
||||||
|
|
||||||
func (pw *Wrapper) SetRunningStatus(remoteAddr string, respErr string) error {
|
func (pw *Wrapper) SetRunningStatus(remoteAddr string, respErr string) error {
|
||||||
pw.mu.Lock()
|
pw.mu.Lock()
|
||||||
defer pw.mu.Unlock()
|
defer pw.mu.Unlock()
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -16,30 +16,22 @@ package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fatedier/golib/crypto"
|
"github.com/fatedier/golib/crypto"
|
||||||
libdial "github.com/fatedier/golib/net/dial"
|
|
||||||
fmux "github.com/hashicorp/yamux"
|
|
||||||
quic "github.com/quic-go/quic-go"
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
|
||||||
"github.com/fatedier/frp/assets"
|
"github.com/fatedier/frp/assets"
|
||||||
"github.com/fatedier/frp/pkg/auth"
|
"github.com/fatedier/frp/pkg/auth"
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/msg"
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
"github.com/fatedier/frp/pkg/util/version"
|
"github.com/fatedier/frp/pkg/util/version"
|
||||||
"github.com/fatedier/frp/pkg/util/wait"
|
"github.com/fatedier/frp/pkg/util/wait"
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
|
@ -75,6 +67,9 @@ type Service struct {
|
||||||
// call cancel to stop service
|
// call cancel to stop service
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
gracefulDuration time.Duration
|
gracefulDuration time.Duration
|
||||||
|
|
||||||
|
connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
|
||||||
|
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(
|
func NewService(
|
||||||
|
@ -84,15 +79,24 @@ func NewService(
|
||||||
cfgFile string,
|
cfgFile string,
|
||||||
) *Service {
|
) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
authSetter: auth.NewAuthSetter(cfg.Auth),
|
authSetter: auth.NewAuthSetter(cfg.Auth),
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
pxyCfgs: pxyCfgs,
|
pxyCfgs: pxyCfgs,
|
||||||
visitorCfgs: visitorCfgs,
|
visitorCfgs: visitorCfgs,
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
|
connectorCreator: NewConnector,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (svr *Service) SetConnectorCreator(h func(context.Context, *v1.ClientCommonConfig) Connector) {
|
||||||
|
svr.connectorCreator = h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (svr *Service) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||||
|
svr.inWorkConnCallback = cb
|
||||||
|
}
|
||||||
|
|
||||||
func (svr *Service) GetController() *Control {
|
func (svr *Service) GetController() *Control {
|
||||||
svr.ctlMu.RLock()
|
svr.ctlMu.RLock()
|
||||||
defer svr.ctlMu.RUnlock()
|
defer svr.ctlMu.RUnlock()
|
||||||
|
@ -101,7 +105,7 @@ func (svr *Service) GetController() *Control {
|
||||||
|
|
||||||
func (svr *Service) Run(ctx context.Context) error {
|
func (svr *Service) Run(ctx context.Context) error {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
svr.ctx = xlog.NewContext(ctx, xlog.New())
|
svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx))
|
||||||
svr.cancel = cancel
|
svr.cancel = cancel
|
||||||
|
|
||||||
// set custom DNSServer
|
// set custom DNSServer
|
||||||
|
@ -173,21 +177,20 @@ func (svr *Service) keepControllerWorking() {
|
||||||
// login creates a connection to frps and registers it self as a client
|
// login creates a connection to frps and registers it self as a client
|
||||||
// conn: control connection
|
// conn: control connection
|
||||||
// session: if it's not nil, using tcp mux
|
// session: if it's not nil, using tcp mux
|
||||||
func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
|
func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
|
||||||
xl := xlog.FromContextSafe(svr.ctx)
|
xl := xlog.FromContextSafe(svr.ctx)
|
||||||
cm = NewConnectionManager(svr.ctx, svr.cfg)
|
connector = svr.connectorCreator(svr.ctx, svr.cfg)
|
||||||
|
if err = connector.Open(); err != nil {
|
||||||
if err = cm.OpenConnection(); err != nil {
|
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cm.Close()
|
connector.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, err = cm.Connect()
|
conn, err = connector.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -226,8 +229,7 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
svr.runID = loginRespMsg.RunID
|
svr.runID = loginRespMsg.RunID
|
||||||
xl.ResetPrefixes()
|
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
||||||
xl.AppendPrefix(svr.runID)
|
|
||||||
|
|
||||||
xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID)
|
xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID)
|
||||||
return
|
return
|
||||||
|
@ -239,7 +241,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
||||||
|
|
||||||
loginFunc := func() error {
|
loginFunc := func() error {
|
||||||
xl.Info("try to connect to server...")
|
xl.Info("try to connect to server...")
|
||||||
conn, cm, err := svr.login()
|
conn, connector, err := svr.login()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xl.Warn("connect to server error: %v", err)
|
xl.Warn("connect to server error: %v", err)
|
||||||
if firstLoginExit {
|
if firstLoginExit {
|
||||||
|
@ -248,13 +250,14 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctl, err := NewControl(svr.ctx, svr.runID, conn, cm,
|
ctl, err := NewControl(svr.ctx, svr.runID, conn, connector,
|
||||||
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
|
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
xl.Error("NewControl error: %v", err)
|
xl.Error("NewControl error: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
ctl.SetInWorkConnCallback(svr.inWorkConnCallback)
|
||||||
|
|
||||||
ctl.Run()
|
ctl.Run()
|
||||||
// close and replace previous control
|
// close and replace previous control
|
||||||
|
@ -314,184 +317,3 @@ func (svr *Service) stop() {
|
||||||
svr.ctl = nil
|
svr.ctl = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionManager is a wrapper for establishing connections to the server.
|
|
||||||
type ConnectionManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
cfg *v1.ClientCommonConfig
|
|
||||||
|
|
||||||
muxSession *fmux.Session
|
|
||||||
quicConn quic.Connection
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *ConnectionManager {
|
|
||||||
return &ConnectionManager{
|
|
||||||
ctx: ctx,
|
|
||||||
cfg: cfg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenConnection opens a underlying connection to the server.
|
|
||||||
// The underlying connection is either a TCP connection or a QUIC connection.
|
|
||||||
// After the underlying connection is established, you can call Connect() to get a stream.
|
|
||||||
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
|
|
||||||
func (cm *ConnectionManager) OpenConnection() error {
|
|
||||||
xl := xlog.FromContextSafe(cm.ctx)
|
|
||||||
|
|
||||||
// special for quic
|
|
||||||
if strings.EqualFold(cm.cfg.Transport.Protocol, "quic") {
|
|
||||||
var tlsConfig *tls.Config
|
|
||||||
var err error
|
|
||||||
sn := cm.cfg.Transport.TLS.ServerName
|
|
||||||
if sn == "" {
|
|
||||||
sn = cm.cfg.ServerAddr
|
|
||||||
}
|
|
||||||
if lo.FromPtr(cm.cfg.Transport.TLS.Enable) {
|
|
||||||
tlsConfig, err = transport.NewClientTLSConfig(
|
|
||||||
cm.cfg.Transport.TLS.CertFile,
|
|
||||||
cm.cfg.Transport.TLS.KeyFile,
|
|
||||||
cm.cfg.Transport.TLS.TrustedCaFile,
|
|
||||||
sn)
|
|
||||||
} else {
|
|
||||||
tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
xl.Warn("fail to build tls configuration, err: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
tlsConfig.NextProtos = []string{"frp"}
|
|
||||||
|
|
||||||
conn, err := quic.DialAddr(
|
|
||||||
cm.ctx,
|
|
||||||
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
|
||||||
tlsConfig, &quic.Config{
|
|
||||||
MaxIdleTimeout: time.Duration(cm.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
|
|
||||||
MaxIncomingStreams: int64(cm.cfg.Transport.QUIC.MaxIncomingStreams),
|
|
||||||
KeepAlivePeriod: time.Duration(cm.cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cm.quicConn = conn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !lo.FromPtr(cm.cfg.Transport.TCPMux) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := cm.realConnect()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmuxCfg := fmux.DefaultConfig()
|
|
||||||
fmuxCfg.KeepAliveInterval = time.Duration(cm.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
|
||||||
fmuxCfg.LogOutput = io.Discard
|
|
||||||
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
|
|
||||||
session, err := fmux.Client(conn, fmuxCfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cm.muxSession = session
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
|
|
||||||
func (cm *ConnectionManager) Connect() (net.Conn, error) {
|
|
||||||
if cm.quicConn != nil {
|
|
||||||
stream, err := cm.quicConn.OpenStreamSync(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return utilnet.QuicStreamToNetConn(stream, cm.quicConn), nil
|
|
||||||
} else if cm.muxSession != nil {
|
|
||||||
stream, err := cm.muxSession.OpenStream()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return stream, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return cm.realConnect()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *ConnectionManager) realConnect() (net.Conn, error) {
|
|
||||||
xl := xlog.FromContextSafe(cm.ctx)
|
|
||||||
var tlsConfig *tls.Config
|
|
||||||
var err error
|
|
||||||
tlsEnable := lo.FromPtr(cm.cfg.Transport.TLS.Enable)
|
|
||||||
if cm.cfg.Transport.Protocol == "wss" {
|
|
||||||
tlsEnable = true
|
|
||||||
}
|
|
||||||
if tlsEnable {
|
|
||||||
sn := cm.cfg.Transport.TLS.ServerName
|
|
||||||
if sn == "" {
|
|
||||||
sn = cm.cfg.ServerAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig, err = transport.NewClientTLSConfig(
|
|
||||||
cm.cfg.Transport.TLS.CertFile,
|
|
||||||
cm.cfg.Transport.TLS.KeyFile,
|
|
||||||
cm.cfg.Transport.TLS.TrustedCaFile,
|
|
||||||
sn)
|
|
||||||
if err != nil {
|
|
||||||
xl.Warn("fail to build tls configuration, err: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyType, addr, auth, err := libdial.ParseProxyURL(cm.cfg.Transport.ProxyURL)
|
|
||||||
if err != nil {
|
|
||||||
xl.Error("fail to parse proxy url")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dialOptions := []libdial.DialOption{}
|
|
||||||
protocol := cm.cfg.Transport.Protocol
|
|
||||||
switch protocol {
|
|
||||||
case "websocket":
|
|
||||||
protocol = "tcp"
|
|
||||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
|
|
||||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
|
||||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
|
||||||
}))
|
|
||||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
|
||||||
case "wss":
|
|
||||||
protocol = "tcp"
|
|
||||||
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
|
|
||||||
// Make sure that if it is wss, the websocket hook is executed after the tls hook.
|
|
||||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
|
|
||||||
default:
|
|
||||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
|
||||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
|
||||||
}))
|
|
||||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
|
||||||
}
|
|
||||||
|
|
||||||
if cm.cfg.Transport.ConnectServerLocalIP != "" {
|
|
||||||
dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.Transport.ConnectServerLocalIP))
|
|
||||||
}
|
|
||||||
dialOptions = append(dialOptions,
|
|
||||||
libdial.WithProtocol(protocol),
|
|
||||||
libdial.WithTimeout(time.Duration(cm.cfg.Transport.DialServerTimeout)*time.Second),
|
|
||||||
libdial.WithKeepAlive(time.Duration(cm.cfg.Transport.DialServerKeepAlive)*time.Second),
|
|
||||||
libdial.WithProxy(proxyType, addr),
|
|
||||||
libdial.WithProxyAuth(auth),
|
|
||||||
)
|
|
||||||
conn, err := libdial.DialContext(
|
|
||||||
cm.ctx,
|
|
||||||
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
|
||||||
dialOptions...,
|
|
||||||
)
|
|
||||||
return conn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *ConnectionManager) Close() error {
|
|
||||||
if cm.quicConn != nil {
|
|
||||||
_ = cm.quicConn.CloseWithError(0, "")
|
|
||||||
}
|
|
||||||
if cm.muxSession != nil {
|
|
||||||
_ = cm.muxSession.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/config"
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/config/v1/validation"
|
"github.com/fatedier/frp/pkg/config/v1/validation"
|
||||||
)
|
)
|
||||||
|
@ -50,8 +51,8 @@ func init() {
|
||||||
}
|
}
|
||||||
clientCfg := v1.ClientCommonConfig{}
|
clientCfg := v1.ClientCommonConfig{}
|
||||||
cmd := NewProxyCommand(string(typ), c, &clientCfg)
|
cmd := NewProxyCommand(string(typ), c, &clientCfg)
|
||||||
RegisterClientCommonConfigFlags(cmd, &clientCfg)
|
config.RegisterClientCommonConfigFlags(cmd, &clientCfg)
|
||||||
RegisterProxyFlags(cmd, c)
|
config.RegisterProxyFlags(cmd, c)
|
||||||
|
|
||||||
// add sub command for visitor
|
// add sub command for visitor
|
||||||
if lo.Contains(visitorTypes, v1.VisitorType(typ)) {
|
if lo.Contains(visitorTypes, v1.VisitorType(typ)) {
|
||||||
|
@ -60,7 +61,7 @@ func init() {
|
||||||
panic("visitor type: " + typ + " not support")
|
panic("visitor type: " + typ + " not support")
|
||||||
}
|
}
|
||||||
visitorCmd := NewVisitorCommand(string(typ), vc, &clientCfg)
|
visitorCmd := NewVisitorCommand(string(typ), vc, &clientCfg)
|
||||||
RegisterVisitorFlags(visitorCmd, vc)
|
config.RegisterVisitorFlags(visitorCmd, vc)
|
||||||
cmd.AddCommand(visitorCmd)
|
cmd.AddCommand(visitorCmd)
|
||||||
}
|
}
|
||||||
rootCmd.AddCommand(cmd)
|
rootCmd.AddCommand(cmd)
|
||||||
|
|
|
@ -1,110 +0,0 @@
|
||||||
// Copyright 2023 The frp Authors
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/config/types"
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PortsRangeSliceFlag struct {
|
|
||||||
V *[]types.PortsRange
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *PortsRangeSliceFlag) String() string {
|
|
||||||
if f.V == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return types.PortsRangeSlice(*f.V).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *PortsRangeSliceFlag) Set(s string) error {
|
|
||||||
slice, err := types.NewPortsRangeSliceFromString(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*f.V = slice
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *PortsRangeSliceFlag) Type() string {
|
|
||||||
return "string"
|
|
||||||
}
|
|
||||||
|
|
||||||
type BoolFuncFlag struct {
|
|
||||||
TrueFunc func()
|
|
||||||
FalseFunc func()
|
|
||||||
|
|
||||||
v bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *BoolFuncFlag) String() string {
|
|
||||||
return strconv.FormatBool(f.v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *BoolFuncFlag) Set(s string) error {
|
|
||||||
f.v = strconv.FormatBool(f.v) == "true"
|
|
||||||
|
|
||||||
if !f.v {
|
|
||||||
if f.FalseFunc != nil {
|
|
||||||
f.FalseFunc()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.TrueFunc != nil {
|
|
||||||
f.TrueFunc()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *BoolFuncFlag) Type() string {
|
|
||||||
return "bool"
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) {
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address")
|
|
||||||
cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port")
|
|
||||||
cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.ProxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address")
|
|
||||||
cmd.PersistentFlags().IntVarP(&c.VhostHTTPPort, "vhost_http_port", "", 0, "vhost http port")
|
|
||||||
cmd.PersistentFlags().IntVarP(&c.VhostHTTPSPort, "vhost_https_port", "", 0, "vhost https port")
|
|
||||||
cmd.PersistentFlags().Int64VarP(&c.VhostHTTPTimeout, "vhost_http_timeout", "", 60, "vhost http response header timeout")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.WebServer.Addr, "dashboard_addr", "", "0.0.0.0", "dashboard address")
|
|
||||||
cmd.PersistentFlags().IntVarP(&c.WebServer.Port, "dashboard_port", "", 0, "dashboard port")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.WebServer.User, "dashboard_user", "", "admin", "dashboard user")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.WebServer.Password, "dashboard_pwd", "", "admin", "dashboard password")
|
|
||||||
cmd.PersistentFlags().BoolVarP(&c.EnablePrometheus, "enable_prometheus", "", false, "enable prometheus dashboard")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "log file")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level")
|
|
||||||
cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log max days")
|
|
||||||
cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token")
|
|
||||||
cmd.PersistentFlags().StringVarP(&c.SubDomainHost, "subdomain_host", "", "", "subdomain host")
|
|
||||||
cmd.PersistentFlags().VarP(&PortsRangeSliceFlag{V: &c.AllowPorts}, "allow_ports", "", "allow ports")
|
|
||||||
cmd.PersistentFlags().Int64VarP(&c.MaxPortsPerClient, "max_ports_per_client", "", 0, "max ports per client")
|
|
||||||
cmd.PersistentFlags().BoolVarP(&c.Transport.TLS.Force, "tls_only", "", false, "frps tls only")
|
|
||||||
|
|
||||||
webServerTLS := v1.TLSConfig{}
|
|
||||||
cmd.PersistentFlags().StringVarP(&webServerTLS.CertFile, "dashboard_tls_cert_file", "", "", "dashboard tls cert file")
|
|
||||||
cmd.PersistentFlags().StringVarP(&webServerTLS.KeyFile, "dashboard_tls_key_file", "", "", "dashboard tls key file")
|
|
||||||
cmd.PersistentFlags().VarP(&BoolFuncFlag{
|
|
||||||
TrueFunc: func() { c.WebServer.TLS = &webServerTLS },
|
|
||||||
}, "dashboard_tls_mode", "", "if enable dashboard tls mode")
|
|
||||||
}
|
|
|
@ -42,7 +42,7 @@ func init() {
|
||||||
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps")
|
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps")
|
||||||
rootCmd.PersistentFlags().BoolVarP(&strictConfigMode, "strict_config", "", false, "strict config parsing mode, unknown fields will cause error")
|
rootCmd.PersistentFlags().BoolVarP(&strictConfigMode, "strict_config", "", false, "strict config parsing mode, unknown fields will cause error")
|
||||||
|
|
||||||
RegisterServerConfigFlags(rootCmd, &serverCfg)
|
config.RegisterServerConfigFlags(rootCmd, &serverCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -21,7 +21,7 @@ require (
|
||||||
github.com/quic-go/quic-go v0.37.4
|
github.com/quic-go/quic-go v0.37.4
|
||||||
github.com/rodaine/table v1.1.0
|
github.com/rodaine/table v1.1.0
|
||||||
github.com/samber/lo v1.38.1
|
github.com/samber/lo v1.38.1
|
||||||
github.com/spf13/cobra v1.7.0
|
github.com/spf13/cobra v1.8.0
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
golang.org/x/crypto v0.15.0
|
golang.org/x/crypto v0.15.0
|
||||||
golang.org/x/net v0.17.0
|
golang.org/x/net v0.17.0
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -16,7 +16,7 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
|
||||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||||
github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o=
|
github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o=
|
||||||
github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc=
|
github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
@ -128,8 +128,8 @@ github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUz
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
|
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
|
||||||
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
|
||||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
|
|
@ -12,10 +12,11 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package sub
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
@ -123,3 +124,89 @@ func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfi
|
||||||
|
|
||||||
c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls")
|
c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PortsRangeSliceFlag struct {
|
||||||
|
V *[]types.PortsRange
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *PortsRangeSliceFlag) String() string {
|
||||||
|
if f.V == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return types.PortsRangeSlice(*f.V).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *PortsRangeSliceFlag) Set(s string) error {
|
||||||
|
slice, err := types.NewPortsRangeSliceFromString(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*f.V = slice
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *PortsRangeSliceFlag) Type() string {
|
||||||
|
return "string"
|
||||||
|
}
|
||||||
|
|
||||||
|
type BoolFuncFlag struct {
|
||||||
|
TrueFunc func()
|
||||||
|
FalseFunc func()
|
||||||
|
|
||||||
|
v bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *BoolFuncFlag) String() string {
|
||||||
|
return strconv.FormatBool(f.v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *BoolFuncFlag) Set(s string) error {
|
||||||
|
f.v = strconv.FormatBool(f.v) == "true"
|
||||||
|
|
||||||
|
if !f.v {
|
||||||
|
if f.FalseFunc != nil {
|
||||||
|
f.FalseFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.TrueFunc != nil {
|
||||||
|
f.TrueFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *BoolFuncFlag) Type() string {
|
||||||
|
return "bool"
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) {
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address")
|
||||||
|
cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port")
|
||||||
|
cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.ProxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address")
|
||||||
|
cmd.PersistentFlags().IntVarP(&c.VhostHTTPPort, "vhost_http_port", "", 0, "vhost http port")
|
||||||
|
cmd.PersistentFlags().IntVarP(&c.VhostHTTPSPort, "vhost_https_port", "", 0, "vhost https port")
|
||||||
|
cmd.PersistentFlags().Int64VarP(&c.VhostHTTPTimeout, "vhost_http_timeout", "", 60, "vhost http response header timeout")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.WebServer.Addr, "dashboard_addr", "", "0.0.0.0", "dashboard address")
|
||||||
|
cmd.PersistentFlags().IntVarP(&c.WebServer.Port, "dashboard_port", "", 0, "dashboard port")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.WebServer.User, "dashboard_user", "", "admin", "dashboard user")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.WebServer.Password, "dashboard_pwd", "", "admin", "dashboard password")
|
||||||
|
cmd.PersistentFlags().BoolVarP(&c.EnablePrometheus, "enable_prometheus", "", false, "enable prometheus dashboard")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "log file")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level")
|
||||||
|
cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log max days")
|
||||||
|
cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token")
|
||||||
|
cmd.PersistentFlags().StringVarP(&c.SubDomainHost, "subdomain_host", "", "", "subdomain host")
|
||||||
|
cmd.PersistentFlags().VarP(&PortsRangeSliceFlag{V: &c.AllowPorts}, "allow_ports", "", "allow ports")
|
||||||
|
cmd.PersistentFlags().Int64VarP(&c.MaxPortsPerClient, "max_ports_per_client", "", 0, "max ports per client")
|
||||||
|
cmd.PersistentFlags().BoolVarP(&c.Transport.TLS.Force, "tls_only", "", false, "frps tls only")
|
||||||
|
|
||||||
|
webServerTLS := v1.TLSConfig{}
|
||||||
|
cmd.PersistentFlags().StringVarP(&webServerTLS.CertFile, "dashboard_tls_cert_file", "", "", "dashboard tls cert file")
|
||||||
|
cmd.PersistentFlags().StringVarP(&webServerTLS.KeyFile, "dashboard_tls_key_file", "", "", "dashboard tls key file")
|
||||||
|
cmd.PersistentFlags().VarP(&BoolFuncFlag{
|
||||||
|
TrueFunc: func() { c.WebServer.TLS = &webServerTLS },
|
||||||
|
}, "dashboard_tls_mode", "", "if enable dashboard tls mode")
|
||||||
|
}
|
|
@ -16,21 +16,11 @@ package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/config/types"
|
"github.com/fatedier/frp/pkg/config/types"
|
||||||
"github.com/fatedier/frp/pkg/util/util"
|
"github.com/fatedier/frp/pkg/util/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SSHTunnelGateway struct {
|
|
||||||
BindPort int `json:"bindPort,omitempty" validate:"gte=0,lte=65535"`
|
|
||||||
PrivateKeyFilePath string `json:"privateKeyFilePath,omitempty"`
|
|
||||||
PublicKeyFilesPath string `json:"publicKeyFilesPath,omitempty"`
|
|
||||||
|
|
||||||
// store all public key file. load all when init
|
|
||||||
PublicKeyFilesMap map[string]ssh.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
APIMetadata
|
APIMetadata
|
||||||
|
|
||||||
|
@ -41,9 +31,6 @@ type ServerConfig struct {
|
||||||
// BindPort specifies the port that the server listens on. By default, this
|
// BindPort specifies the port that the server listens on. By default, this
|
||||||
// value is 7000.
|
// value is 7000.
|
||||||
BindPort int `json:"bindPort,omitempty"`
|
BindPort int `json:"bindPort,omitempty"`
|
||||||
|
|
||||||
SSHTunnelGateway SSHTunnelGateway `json:"sshGatewayConfig,omitempty"`
|
|
||||||
|
|
||||||
// KCPBindPort specifies the KCP port that the server listens on. If this
|
// KCPBindPort specifies the KCP port that the server listens on. If this
|
||||||
// value is 0, the server will not listen for KCP connections.
|
// value is 0, the server will not listen for KCP connections.
|
||||||
KCPBindPort int `json:"kcpBindPort,omitempty"`
|
KCPBindPort int `json:"kcpBindPort,omitempty"`
|
||||||
|
@ -80,6 +67,8 @@ type ServerConfig struct {
|
||||||
// value is "", a default page will be displayed.
|
// value is "", a default page will be displayed.
|
||||||
Custom404Page string `json:"custom404Page,omitempty"`
|
Custom404Page string `json:"custom404Page,omitempty"`
|
||||||
|
|
||||||
|
SSHTunnelGateway SSHTunnelGateway `json:"sshTunnelGateway,omitempty"`
|
||||||
|
|
||||||
WebServer WebServerConfig `json:"webServer,omitempty"`
|
WebServer WebServerConfig `json:"webServer,omitempty"`
|
||||||
// EnablePrometheus will export prometheus metrics on webserver address
|
// EnablePrometheus will export prometheus metrics on webserver address
|
||||||
// in /metrics api.
|
// in /metrics api.
|
||||||
|
@ -114,6 +103,7 @@ func (c *ServerConfig) Complete() {
|
||||||
c.Log.Complete()
|
c.Log.Complete()
|
||||||
c.Transport.Complete()
|
c.Transport.Complete()
|
||||||
c.WebServer.Complete()
|
c.WebServer.Complete()
|
||||||
|
c.SSHTunnelGateway.Complete()
|
||||||
|
|
||||||
c.BindAddr = util.EmptyOr(c.BindAddr, "0.0.0.0")
|
c.BindAddr = util.EmptyOr(c.BindAddr, "0.0.0.0")
|
||||||
c.BindPort = util.EmptyOr(c.BindPort, 7000)
|
c.BindPort = util.EmptyOr(c.BindPort, 7000)
|
||||||
|
@ -202,3 +192,14 @@ type TLSServerConfig struct {
|
||||||
|
|
||||||
TLSConfig
|
TLSConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SSHTunnelGateway struct {
|
||||||
|
BindPort int `json:"bindPort,omitempty"`
|
||||||
|
PrivateKeyFile string `json:"privateKeyFile,omitempty"`
|
||||||
|
AutoGenPrivateKeyPath string `json:"autoGenPrivateKeyPath,omitempty"`
|
||||||
|
AuthorizedKeysFile string `json:"authorizedKeysFile,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSHTunnelGateway) Complete() {
|
||||||
|
c.AutoGenPrivateKeyPath = util.EmptyOr(c.AutoGenPrivateKeyPath, "./.autogen_ssh_key")
|
||||||
|
}
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
package v1
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// custom define
|
|
||||||
SSHClientLoginUserPrefix = "_frpc_ssh_client_"
|
|
||||||
)
|
|
||||||
|
|
||||||
// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format
|
|
||||||
func GeneratePrivateKey() ([]byte, error) {
|
|
||||||
privateKey, err := generatePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("gen private key error")
|
|
||||||
}
|
|
||||||
|
|
||||||
privBlock := pem.Block{
|
|
||||||
Type: "RSA PRIVATE KEY",
|
|
||||||
Headers: nil,
|
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
|
||||||
}
|
|
||||||
|
|
||||||
return pem.EncodeToMemory(&privBlock), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generatePrivateKey creates a RSA Private Key of specified byte size
|
|
||||||
func generatePrivateKey() (*rsa.PrivateKey, error) {
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = privateKey.Validate()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return privateKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadSSHPublicKeyFilesInDir(dirPath string) (map[string]ssh.PublicKey, error) {
|
|
||||||
fileMap := make(map[string]ssh.PublicKey)
|
|
||||||
files, err := os.ReadDir(dirPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, file := range files {
|
|
||||||
filePath := filepath.Join(dirPath, file.Name())
|
|
||||||
content, err := os.ReadFile(filePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(content)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fileMap[ssh.FingerprintSHA256(parsedAuthorizedKey)] = parsedAuthorizedKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return fileMap, nil
|
|
||||||
}
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
// Copyright 2023 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
|
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Gateway struct {
|
||||||
|
bindPort int
|
||||||
|
ln net.Listener
|
||||||
|
|
||||||
|
serverPeerListener *utilnet.InternalListener
|
||||||
|
|
||||||
|
sshConfig *ssh.ServerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGateway(
|
||||||
|
cfg v1.SSHTunnelGateway, bindAddr string,
|
||||||
|
serverPeerListener *utilnet.InternalListener,
|
||||||
|
) (*Gateway, error) {
|
||||||
|
sshConfig := &ssh.ServerConfig{}
|
||||||
|
|
||||||
|
// privateKey
|
||||||
|
var (
|
||||||
|
privateKeyBytes []byte
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if cfg.PrivateKeyFile != "" {
|
||||||
|
privateKeyBytes, err = os.ReadFile(cfg.PrivateKeyFile)
|
||||||
|
} else {
|
||||||
|
if cfg.AutoGenPrivateKeyPath != "" {
|
||||||
|
privateKeyBytes, _ = os.ReadFile(cfg.AutoGenPrivateKeyPath)
|
||||||
|
}
|
||||||
|
if len(privateKeyBytes) == 0 {
|
||||||
|
privateKeyBytes, err = transport.NewRandomPrivateKey()
|
||||||
|
if err == nil && cfg.AutoGenPrivateKeyPath != "" {
|
||||||
|
err = os.WriteFile(cfg.AutoGenPrivateKeyPath, privateKeyBytes, 0o600)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
privateKey, err := ssh.ParsePrivateKey(privateKeyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sshConfig.AddHostKey(privateKey)
|
||||||
|
|
||||||
|
sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
if cfg.AuthorizedKeysFile == "" {
|
||||||
|
return &ssh.Permissions{
|
||||||
|
Extensions: map[string]string{
|
||||||
|
"user": "",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("internal error")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, ok := authorizedKeysMap[string(key.Marshal())]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown public key for remoteAddr %q", conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
return &ssh.Permissions{
|
||||||
|
Extensions: map[string]string{
|
||||||
|
"user": user,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(cfg.BindPort)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Gateway{
|
||||||
|
bindPort: cfg.BindPort,
|
||||||
|
ln: ln,
|
||||||
|
serverPeerListener: serverPeerListener,
|
||||||
|
sshConfig: sshConfig,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) Run() {
|
||||||
|
for {
|
||||||
|
conn, err := g.ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go g.handleConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) handleConn(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ts, err := NewTunnelServer(conn, g.sshConfig, g.serverPeerListener)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := ts.Run(); err != nil {
|
||||||
|
log.Error("ssh tunnel server run error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadAuthorizedKeysFromFile(path string) (map[string]string, error) {
|
||||||
|
authorizedKeysMap := make(map[string]string) // value is username
|
||||||
|
authorizedKeysBytes, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for len(authorizedKeysBytes) > 0 {
|
||||||
|
pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authorizedKeysMap[string(pubKey.Marshal())] = strings.TrimSpace(comment)
|
||||||
|
authorizedKeysBytes = rest
|
||||||
|
}
|
||||||
|
return authorizedKeysMap, nil
|
||||||
|
}
|
|
@ -0,0 +1,279 @@
|
||||||
|
// Copyright 2023 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
libio "github.com/fatedier/golib/io"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/config"
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
"github.com/fatedier/frp/pkg/util/util"
|
||||||
|
"github.com/fatedier/frp/pkg/util/xlog"
|
||||||
|
"github.com/fatedier/frp/pkg/virtual"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||||
|
ChannelTypeServerOpenChannel = "forwarded-tcpip"
|
||||||
|
RequestTypeForward = "tcpip-forward"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tcpipForward struct {
|
||||||
|
Host string
|
||||||
|
Port uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||||
|
type forwardedTCPPayload struct {
|
||||||
|
Addr string
|
||||||
|
Port uint32
|
||||||
|
|
||||||
|
// can be default empty value but do not delete it
|
||||||
|
// because ssh protocol shoule be reserved
|
||||||
|
OriginAddr string
|
||||||
|
OriginPort uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelServer struct {
|
||||||
|
underlyingConn net.Conn
|
||||||
|
sshConn *ssh.ServerConn
|
||||||
|
sc *ssh.ServerConfig
|
||||||
|
|
||||||
|
vc *virtual.Client
|
||||||
|
serverPeerListener *utilnet.InternalListener
|
||||||
|
doneCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, serverPeerListener *utilnet.InternalListener) (*TunnelServer, error) {
|
||||||
|
s := &TunnelServer{
|
||||||
|
underlyingConn: conn,
|
||||||
|
sc: sc,
|
||||||
|
serverPeerListener: serverPeerListener,
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) Run() error {
|
||||||
|
sshConn, channels, requests, err := ssh.NewServerConn(s.underlyingConn, s.sc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.sshConn = sshConn
|
||||||
|
|
||||||
|
addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
|
||||||
|
pc.Complete(clientCfg.User)
|
||||||
|
|
||||||
|
s.vc = virtual.NewClient(clientCfg)
|
||||||
|
// join workConn and ssh channel
|
||||||
|
s.vc.SetInWorkConnCallback(func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
|
||||||
|
c, err := s.openConn(addr)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
libio.Join(c, workConn)
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
// transfer connection from virtual client to server peer listener
|
||||||
|
go func() {
|
||||||
|
l := s.vc.PeerListener()
|
||||||
|
for {
|
||||||
|
conn, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.serverPeerListener.PutConn(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
|
||||||
|
ctx := xlog.NewContext(context.Background(), xl)
|
||||||
|
go func() {
|
||||||
|
_ = s.vc.Run(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})
|
||||||
|
|
||||||
|
_ = sshConn.Wait()
|
||||||
|
_ = sshConn.Close()
|
||||||
|
s.vc.Close()
|
||||||
|
close(s.doneCh)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) waitForwardAddrAndExtraPayload(
|
||||||
|
channels <-chan ssh.NewChannel,
|
||||||
|
requests <-chan *ssh.Request,
|
||||||
|
timeout time.Duration,
|
||||||
|
) (*tcpipForward, string, error) {
|
||||||
|
addrCh := make(chan *tcpipForward, 1)
|
||||||
|
extraPayloadCh := make(chan string, 1)
|
||||||
|
|
||||||
|
// get forward address
|
||||||
|
go func() {
|
||||||
|
addrGot := false
|
||||||
|
for req := range requests {
|
||||||
|
switch req.Type {
|
||||||
|
case RequestTypeForward:
|
||||||
|
if !addrGot {
|
||||||
|
payload := tcpipForward{}
|
||||||
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addrGot = true
|
||||||
|
addrCh <- &payload
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if req.WantReply {
|
||||||
|
_ = req.Reply(true, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// get extra payload
|
||||||
|
go func() {
|
||||||
|
for newChannel := range channels {
|
||||||
|
// extraPayload will send to extraPayloadCh
|
||||||
|
go s.handleNewChannel(newChannel, extraPayloadCh)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr *tcpipForward
|
||||||
|
extraPayload string
|
||||||
|
)
|
||||||
|
|
||||||
|
timer := time.NewTimer(timeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case v := <-addrCh:
|
||||||
|
addr = v
|
||||||
|
case extra := <-extraPayloadCh:
|
||||||
|
extraPayload = extra
|
||||||
|
case <-timer.C:
|
||||||
|
return nil, "", fmt.Errorf("get addr and extra payload timeout")
|
||||||
|
}
|
||||||
|
if addr != nil && extraPayload != "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return addr, extraPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) {
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
args := strings.Split(extraPayload, " ")
|
||||||
|
if len(args) < 1 {
|
||||||
|
return nil, nil, fmt.Errorf("invalid extra payload")
|
||||||
|
}
|
||||||
|
proxyType := strings.TrimSpace(args[0])
|
||||||
|
supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"}
|
||||||
|
if !lo.Contains(supportTypes, proxyType) {
|
||||||
|
return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
|
||||||
|
}
|
||||||
|
pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType))
|
||||||
|
if pc == nil {
|
||||||
|
return nil, nil, fmt.Errorf("new proxy configurer error")
|
||||||
|
}
|
||||||
|
config.RegisterProxyFlags(cmd, pc)
|
||||||
|
|
||||||
|
clientCfg := v1.ClientCommonConfig{}
|
||||||
|
config.RegisterClientCommonConfigFlags(cmd, &clientCfg)
|
||||||
|
|
||||||
|
if err := cmd.ParseFlags(args); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err)
|
||||||
|
}
|
||||||
|
return &clientCfg, pc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) {
|
||||||
|
ch, reqs, err := channel.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go s.keepAlive(ch)
|
||||||
|
|
||||||
|
for req := range reqs {
|
||||||
|
if req.Type != "exec" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(req.Payload) <= 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
|
||||||
|
if len(req.Payload) < int(end) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
extraPayload := string(req.Payload[4:end])
|
||||||
|
select {
|
||||||
|
case extraPayloadCh <- extraPayload:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) keepAlive(ch ssh.Channel) {
|
||||||
|
tk := time.NewTicker(time.Second * 30)
|
||||||
|
defer tk.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-tk.C:
|
||||||
|
_, err := ch.SendRequest("heartbeat", false, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-s.doneCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
|
||||||
|
payload := forwardedTCPPayload{
|
||||||
|
Addr: addr.Host,
|
||||||
|
Port: addr.Port,
|
||||||
|
}
|
||||||
|
channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open ssh channel error: %v", err)
|
||||||
|
}
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
conn := utilnet.WrapReadWriteCloserToConn(channel, s.underlyingConn)
|
||||||
|
return conn, nil
|
||||||
|
}
|
|
@ -1,497 +0,0 @@
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
gerror "github.com/fatedier/golib/errors"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ssh protocol define
|
|
||||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
|
||||||
ChannelTypeServerOpenChannel = "forwarded-tcpip"
|
|
||||||
RequestTypeForward = "tcpip-forward"
|
|
||||||
|
|
||||||
// golang ssh package define.
|
|
||||||
// https://pkg.go.dev/golang.org/x/crypto/ssh
|
|
||||||
RequestTypeHeartbeat = "keepalive@openssh.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 当 proxy 失败会返回该错误
|
|
||||||
type VProxyError struct{}
|
|
||||||
|
|
||||||
// ssh protocol define
|
|
||||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
|
||||||
// parse ssh client cmds input
|
|
||||||
type forwardedTCPPayload struct {
|
|
||||||
Addr string
|
|
||||||
Port uint32
|
|
||||||
|
|
||||||
// can be default empty value but do not delete it
|
|
||||||
// because ssh protocol shoule be reserved
|
|
||||||
OriginAddr string
|
|
||||||
OriginPort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// custom define
|
|
||||||
// parse ssh client cmds input
|
|
||||||
type CmdPayload struct {
|
|
||||||
Address string
|
|
||||||
Port uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// custom define
|
|
||||||
// with frp control cmds
|
|
||||||
type ExtraPayload struct {
|
|
||||||
Type string
|
|
||||||
|
|
||||||
// TODO port can be set by extra message and priority to ssh raw cmd
|
|
||||||
Address string
|
|
||||||
Port uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type Service struct {
|
|
||||||
tcpConn net.Conn
|
|
||||||
cfg *ssh.ServerConfig
|
|
||||||
|
|
||||||
sshConn *ssh.ServerConn
|
|
||||||
gChannel <-chan ssh.NewChannel
|
|
||||||
gReq <-chan *ssh.Request
|
|
||||||
|
|
||||||
addrPayloadCh chan CmdPayload
|
|
||||||
extraPayloadCh chan ExtraPayload
|
|
||||||
|
|
||||||
proxyPayloadCh chan v1.ProxyConfigurer
|
|
||||||
replyCh chan interface{}
|
|
||||||
|
|
||||||
closeCh chan struct{}
|
|
||||||
exit int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSSHService(
|
|
||||||
tcpConn net.Conn,
|
|
||||||
cfg *ssh.ServerConfig,
|
|
||||||
proxyPayloadCh chan v1.ProxyConfigurer,
|
|
||||||
replyCh chan interface{},
|
|
||||||
) (ss *Service, err error) {
|
|
||||||
ss = &Service{
|
|
||||||
tcpConn: tcpConn,
|
|
||||||
cfg: cfg,
|
|
||||||
|
|
||||||
addrPayloadCh: make(chan CmdPayload),
|
|
||||||
extraPayloadCh: make(chan ExtraPayload),
|
|
||||||
|
|
||||||
proxyPayloadCh: proxyPayloadCh,
|
|
||||||
replyCh: replyCh,
|
|
||||||
|
|
||||||
closeCh: make(chan struct{}),
|
|
||||||
exit: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
ss.sshConn, ss.gChannel, ss.gReq, err = ssh.NewServerConn(tcpConn, cfg)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("ssh handshake error: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("ssh connection success")
|
|
||||||
|
|
||||||
return ss, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) Run() {
|
|
||||||
go ss.loopGenerateProxy()
|
|
||||||
go ss.loopParseCmdPayload()
|
|
||||||
go ss.loopParseExtraPayload()
|
|
||||||
go ss.loopReply()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) Exit() <-chan struct{} {
|
|
||||||
return ss.closeCh
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) Close() {
|
|
||||||
if atomic.LoadInt32(&ss.exit) == 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ss.closeCh:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
close(ss.closeCh)
|
|
||||||
close(ss.addrPayloadCh)
|
|
||||||
close(ss.extraPayloadCh)
|
|
||||||
|
|
||||||
_ = ss.sshConn.Wait()
|
|
||||||
|
|
||||||
ss.sshConn.Close()
|
|
||||||
ss.tcpConn.Close()
|
|
||||||
|
|
||||||
atomic.StoreInt32(&ss.exit, 1)
|
|
||||||
|
|
||||||
log.Info("ssh service close")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) loopParseCmdPayload() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case req, ok := <-ss.gReq:
|
|
||||||
if !ok {
|
|
||||||
log.Info("global request is close")
|
|
||||||
ss.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch req.Type {
|
|
||||||
case RequestTypeForward:
|
|
||||||
var addrPayload CmdPayload
|
|
||||||
if err := ssh.Unmarshal(req.Payload, &addrPayload); err != nil {
|
|
||||||
log.Error("ssh unmarshal error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = gerror.PanicToError(func() {
|
|
||||||
ss.addrPayloadCh <- addrPayload
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
if req.Type == RequestTypeHeartbeat {
|
|
||||||
log.Debug("ssh heartbeat data")
|
|
||||||
} else {
|
|
||||||
log.Info("default req, data: %v", req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if req.WantReply {
|
|
||||||
err := req.Reply(true, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("reply to ssh client error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case <-ss.closeCh:
|
|
||||||
log.Info("loop parse cmd payload close")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) loopSendHeartbeat(ch ssh.Channel) {
|
|
||||||
tk := time.NewTicker(time.Second * 60)
|
|
||||||
defer tk.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-tk.C:
|
|
||||||
ok, err := ch.SendRequest("heartbeat", false, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("channel send req error: %v", err)
|
|
||||||
if err == io.EOF {
|
|
||||||
ss.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.Debug("heartbeat send success, ok: %v", ok)
|
|
||||||
case <-ss.closeCh:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) loopParseExtraPayload() {
|
|
||||||
log.Info("loop parse extra payload start")
|
|
||||||
|
|
||||||
for newChannel := range ss.gChannel {
|
|
||||||
ch, req, err := newChannel.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("channel accept error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go ss.loopSendHeartbeat(ch)
|
|
||||||
|
|
||||||
go func(req <-chan *ssh.Request) {
|
|
||||||
for r := range req {
|
|
||||||
if len(r.Payload) <= 4 {
|
|
||||||
log.Info("r.payload is less than 4")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") {
|
|
||||||
log.Info("ssh protocol exchange data")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// [4byte data_len|data]
|
|
||||||
end := 4 + binary.BigEndian.Uint32(r.Payload[:4])
|
|
||||||
if end > uint32(len(r.Payload)) {
|
|
||||||
end = uint32(len(r.Payload))
|
|
||||||
}
|
|
||||||
p := string(r.Payload[4:end])
|
|
||||||
|
|
||||||
msg, err := parseSSHExtraMessage(p)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
_ = gerror.PanicToError(func() {
|
|
||||||
ss.extraPayloadCh <- msg
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}(req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) SSHConn() *ssh.ServerConn {
|
|
||||||
return ss.sshConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) TCPConn() net.Conn {
|
|
||||||
return ss.tcpConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) loopReply() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ss.closeCh:
|
|
||||||
log.Info("loop reply close")
|
|
||||||
return
|
|
||||||
case req := <-ss.replyCh:
|
|
||||||
switch req.(type) {
|
|
||||||
case *VProxyError:
|
|
||||||
log.Error("run frp proxy error, close ssh service")
|
|
||||||
ss.Close()
|
|
||||||
default:
|
|
||||||
// TODO
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss *Service) loopGenerateProxy() {
|
|
||||||
log.Info("loop generate proxy start")
|
|
||||||
|
|
||||||
for {
|
|
||||||
if atomic.LoadInt32(&ss.exit) == 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
wg := new(sync.WaitGroup)
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
var p1 CmdPayload
|
|
||||||
var p2 ExtraPayload
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ss.closeCh:
|
|
||||||
return
|
|
||||||
case p1 = <-ss.addrPayloadCh:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ss.closeCh:
|
|
||||||
return
|
|
||||||
case p2 = <-ss.extraPayloadCh:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if atomic.LoadInt32(&ss.exit) == 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch p2.Type {
|
|
||||||
case "http":
|
|
||||||
case "tcp":
|
|
||||||
ss.proxyPayloadCh <- &v1.TCPProxyConfig{
|
|
||||||
ProxyBaseConfig: v1.ProxyBaseConfig{
|
|
||||||
Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()),
|
|
||||||
Type: p2.Type,
|
|
||||||
|
|
||||||
ProxyBackend: v1.ProxyBackend{
|
|
||||||
LocalIP: p1.Address,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
RemotePort: int(p1.Port),
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Warn("invalid frp proxy type: %v", p2.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSSHExtraMessage(s string) (p ExtraPayload, err error) {
|
|
||||||
sn := len(s)
|
|
||||||
|
|
||||||
log.Info("parse ssh extra message: %v", s)
|
|
||||||
|
|
||||||
ss := strings.Fields(s)
|
|
||||||
if len(ss) == 0 {
|
|
||||||
if sn != 0 {
|
|
||||||
ss = append(ss, s)
|
|
||||||
} else {
|
|
||||||
return p, fmt.Errorf("invalid ssh input, args: %v", ss)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range ss {
|
|
||||||
ss[i] = strings.TrimSpace(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ss[0] != "tcp" && ss[0] != "http" {
|
|
||||||
return p, fmt.Errorf("only support tcp/http now")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch ss[0] {
|
|
||||||
case "tcp":
|
|
||||||
tcpCmd, err := ParseTCPCommand(ss)
|
|
||||||
if err != nil {
|
|
||||||
return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
port, _ := strconv.Atoi(tcpCmd.Port)
|
|
||||||
|
|
||||||
p = ExtraPayload{
|
|
||||||
Type: "tcp",
|
|
||||||
Address: tcpCmd.Address,
|
|
||||||
Port: uint32(port),
|
|
||||||
}
|
|
||||||
case "http":
|
|
||||||
httpCmd, err := ParseHTTPCommand(ss)
|
|
||||||
if err != nil {
|
|
||||||
return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = httpCmd
|
|
||||||
|
|
||||||
p = ExtraPayload{
|
|
||||||
Type: "http",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type HTTPCommand struct {
|
|
||||||
Domain string
|
|
||||||
BasicAuthUser string
|
|
||||||
BasicAuthPass string
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseHTTPCommand(params []string) (*HTTPCommand, error) {
|
|
||||||
if len(params) < 2 {
|
|
||||||
return nil, errors.New("invalid HTTP command")
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
basicAuth string
|
|
||||||
domainURL string
|
|
||||||
basicAuthUser string
|
|
||||||
basicAuthPass string
|
|
||||||
)
|
|
||||||
|
|
||||||
fs := flag.NewFlagSet("http", flag.ContinueOnError)
|
|
||||||
fs.StringVar(&basicAuth, "basic-auth", "", "")
|
|
||||||
fs.StringVar(&domainURL, "domain", "", "")
|
|
||||||
|
|
||||||
fs.SetOutput(&nullWriter{}) // Disables usage output
|
|
||||||
|
|
||||||
err := fs.Parse(params[2:])
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, flag.ErrHelp) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if basicAuth != "" {
|
|
||||||
authParts := strings.SplitN(basicAuth, ":", 2)
|
|
||||||
basicAuthUser = authParts[0]
|
|
||||||
if len(authParts) > 1 {
|
|
||||||
basicAuthPass = authParts[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
httpCmd := &HTTPCommand{
|
|
||||||
Domain: domainURL,
|
|
||||||
BasicAuthUser: basicAuthUser,
|
|
||||||
BasicAuthPass: basicAuthPass,
|
|
||||||
}
|
|
||||||
return httpCmd, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPCommand struct {
|
|
||||||
Address string
|
|
||||||
Port string
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseTCPCommand(params []string) (*TCPCommand, error) {
|
|
||||||
if len(params) == 0 || params[0] != "tcp" {
|
|
||||||
return nil, errors.New("invalid TCP command")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(params) == 1 {
|
|
||||||
return &TCPCommand{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
address string
|
|
||||||
port string
|
|
||||||
)
|
|
||||||
|
|
||||||
fs := flag.NewFlagSet("tcp", flag.ContinueOnError)
|
|
||||||
fs.StringVar(&address, "address", "", "The IP address to listen on")
|
|
||||||
fs.StringVar(&port, "port", "", "The port to listen on")
|
|
||||||
fs.SetOutput(&nullWriter{}) // Disables usage output
|
|
||||||
|
|
||||||
args := params[1:]
|
|
||||||
err := fs.Parse(args)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, flag.ErrHelp) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedAddr, err := net.ResolveIPAddr("ip", address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if _, err := net.LookupPort("tcp", port); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpCmd := &TCPCommand{
|
|
||||||
Address: parsedAddr.String(),
|
|
||||||
Port: port,
|
|
||||||
}
|
|
||||||
return tcpCmd, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type nullWriter struct{}
|
|
||||||
|
|
||||||
func (w *nullWriter) Write(p []byte) (n int, err error) { return len(p), nil }
|
|
|
@ -1,185 +0,0 @@
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/config"
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
||||||
"github.com/fatedier/frp/pkg/msg"
|
|
||||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
|
||||||
frp_net "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
"github.com/fatedier/frp/pkg/util/util"
|
|
||||||
"github.com/fatedier/frp/pkg/util/xlog"
|
|
||||||
"github.com/fatedier/frp/server/controller"
|
|
||||||
"github.com/fatedier/frp/server/proxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
// VirtualService is a client VirtualService run in frps
|
|
||||||
type VirtualService struct {
|
|
||||||
clientCfg v1.ClientCommonConfig
|
|
||||||
pxyCfg v1.ProxyConfigurer
|
|
||||||
serverCfg v1.ServerConfig
|
|
||||||
|
|
||||||
sshSvc *Service
|
|
||||||
|
|
||||||
// uniq id got from frps, attach it in loginMsg
|
|
||||||
runID string
|
|
||||||
loginMsg *msg.Login
|
|
||||||
|
|
||||||
// All resource managers and controllers
|
|
||||||
rc *controller.ResourceController
|
|
||||||
|
|
||||||
exit uint32 // 0 means not exit
|
|
||||||
// SSHService context
|
|
||||||
ctx context.Context
|
|
||||||
// call cancel to stop SSHService
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
replyCh chan interface{}
|
|
||||||
pxy proxy.Proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVirtualService(
|
|
||||||
ctx context.Context,
|
|
||||||
clientCfg v1.ClientCommonConfig,
|
|
||||||
serverCfg v1.ServerConfig,
|
|
||||||
logMsg msg.Login,
|
|
||||||
rc *controller.ResourceController,
|
|
||||||
pxyCfg v1.ProxyConfigurer,
|
|
||||||
sshSvc *Service,
|
|
||||||
replyCh chan interface{},
|
|
||||||
) (svr *VirtualService, err error) {
|
|
||||||
svr = &VirtualService{
|
|
||||||
clientCfg: clientCfg,
|
|
||||||
serverCfg: serverCfg,
|
|
||||||
rc: rc,
|
|
||||||
|
|
||||||
loginMsg: &logMsg,
|
|
||||||
|
|
||||||
sshSvc: sshSvc,
|
|
||||||
pxyCfg: pxyCfg,
|
|
||||||
|
|
||||||
ctx: ctx,
|
|
||||||
exit: 0,
|
|
||||||
|
|
||||||
replyCh: replyCh,
|
|
||||||
}
|
|
||||||
|
|
||||||
svr.runID, err = util.RandID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go svr.loopCheck()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *VirtualService) Run(ctx context.Context) (err error) {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
svr.ctx = xlog.NewContext(ctx, xlog.New())
|
|
||||||
svr.cancel = cancel
|
|
||||||
|
|
||||||
remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{
|
|
||||||
ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name,
|
|
||||||
ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type,
|
|
||||||
RemotePort: svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("run a reverse proxy on port: %v", remoteAddr)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *VirtualService) Close() {
|
|
||||||
svr.GracefulClose(time.Duration(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *VirtualService) GracefulClose(d time.Duration) {
|
|
||||||
atomic.StoreUint32(&svr.exit, 1)
|
|
||||||
svr.pxy.Close()
|
|
||||||
|
|
||||||
if svr.cancel != nil {
|
|
||||||
svr.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
svr.replyCh <- &VProxyError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *VirtualService) loopCheck() {
|
|
||||||
<-svr.sshSvc.Exit()
|
|
||||||
svr.pxy.Close()
|
|
||||||
log.Info("virtual client service close")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
|
|
||||||
var pxyConf v1.ProxyConfigurer
|
|
||||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, &svr.serverCfg)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// User info
|
|
||||||
userInfo := plugin.UserInfo{
|
|
||||||
User: svr.loginMsg.User,
|
|
||||||
Metas: svr.loginMsg.Metas,
|
|
||||||
RunID: svr.runID,
|
|
||||||
}
|
|
||||||
|
|
||||||
svr.pxy, err = proxy.NewProxy(svr.ctx, &proxy.Options{
|
|
||||||
LoginMsg: svr.loginMsg,
|
|
||||||
UserInfo: userInfo,
|
|
||||||
Configurer: pxyConf,
|
|
||||||
ResourceController: svr.rc,
|
|
||||||
|
|
||||||
GetWorkConnFn: svr.GetWorkConn,
|
|
||||||
PoolCount: 10,
|
|
||||||
|
|
||||||
ServerCfg: &svr.serverCfg,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return remoteAddr, err
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAddr, err = svr.pxy.Run()
|
|
||||||
if err != nil {
|
|
||||||
log.Warn("proxy run error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
log.Warn("proxy close")
|
|
||||||
svr.pxy.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) {
|
|
||||||
// tell ssh client open a new stream for work
|
|
||||||
payload := forwardedTCPPayload{
|
|
||||||
Addr: svr.serverCfg.BindAddr, // TODO refine
|
|
||||||
Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort),
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("open ssh channel error: %v", err)
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
|
|
||||||
workConn = frp_net.WrapReadWriteCloserToConn(channel, svr.sshSvc.tcpConn)
|
|
||||||
return workConn, nil
|
|
||||||
}
|
|
|
@ -128,3 +128,15 @@ func NewClientTLSConfig(certPath, keyPath, caPath, serverName string) (*tls.Conf
|
||||||
|
|
||||||
return base, nil
|
return base, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewRandomPrivateKey() ([]byte, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
return keyPEM, nil
|
||||||
|
}
|
||||||
|
|
|
@ -15,40 +15,81 @@
|
||||||
package xlog
|
package xlog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type LogPrefix struct {
|
||||||
|
// Name is the name of the prefix, it won't be displayed in log but used to identify the prefix.
|
||||||
|
Name string
|
||||||
|
// Value is the value of the prefix, it will be displayed in log.
|
||||||
|
Value string
|
||||||
|
// The prefix with higher priority will be displayed first, default is 10.
|
||||||
|
Priority int
|
||||||
|
}
|
||||||
|
|
||||||
// Logger is not thread safety for operations on prefix
|
// Logger is not thread safety for operations on prefix
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
prefixes []string
|
prefixes []LogPrefix
|
||||||
|
|
||||||
prefixString string
|
prefixString string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() *Logger {
|
func New() *Logger {
|
||||||
return &Logger{
|
return &Logger{
|
||||||
prefixes: make([]string, 0),
|
prefixes: make([]LogPrefix, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) ResetPrefixes() (old []string) {
|
func (l *Logger) ResetPrefixes() (old []LogPrefix) {
|
||||||
old = l.prefixes
|
old = l.prefixes
|
||||||
l.prefixes = make([]string, 0)
|
l.prefixes = make([]LogPrefix, 0)
|
||||||
l.prefixString = ""
|
l.prefixString = ""
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) AppendPrefix(prefix string) *Logger {
|
func (l *Logger) AppendPrefix(prefix string) *Logger {
|
||||||
l.prefixes = append(l.prefixes, prefix)
|
return l.AddPrefix(LogPrefix{
|
||||||
l.prefixString += "[" + prefix + "] "
|
Name: prefix,
|
||||||
|
Value: prefix,
|
||||||
|
Priority: 10,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) AddPrefix(prefix LogPrefix) *Logger {
|
||||||
|
found := false
|
||||||
|
if prefix.Priority <= 0 {
|
||||||
|
prefix.Priority = 10
|
||||||
|
}
|
||||||
|
for _, p := range l.prefixes {
|
||||||
|
if p.Name == prefix.Name {
|
||||||
|
found = true
|
||||||
|
p.Value = prefix.Value
|
||||||
|
p.Priority = prefix.Priority
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
l.prefixes = append(l.prefixes, prefix)
|
||||||
|
}
|
||||||
|
l.renderPrefixString()
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Logger) renderPrefixString() {
|
||||||
|
sort.SliceStable(l.prefixes, func(i, j int) bool {
|
||||||
|
return l.prefixes[i].Priority < l.prefixes[j].Priority
|
||||||
|
})
|
||||||
|
l.prefixString = ""
|
||||||
|
for _, v := range l.prefixes {
|
||||||
|
l.prefixString += "[" + v.Value + "] "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Logger) Spawn() *Logger {
|
func (l *Logger) Spawn() *Logger {
|
||||||
nl := New()
|
nl := New()
|
||||||
for _, v := range l.prefixes {
|
nl.prefixes = append(nl.prefixes, l.prefixes...)
|
||||||
nl.AppendPrefix(v)
|
nl.renderPrefixString()
|
||||||
}
|
|
||||||
return nl
|
return nl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
// Copyright 2023 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package virtual
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/client"
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
l *utilnet.InternalListener
|
||||||
|
svr *client.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(cfg *v1.ClientCommonConfig) *Client {
|
||||||
|
cfg.Complete()
|
||||||
|
|
||||||
|
ln := utilnet.NewInternalListener()
|
||||||
|
|
||||||
|
svr := client.NewService(cfg, nil, nil, "")
|
||||||
|
svr.SetConnectorCreator(func(context.Context, *v1.ClientCommonConfig) client.Connector {
|
||||||
|
return &pipeConnector{
|
||||||
|
peerListener: ln,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
l: ln,
|
||||||
|
svr: svr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) PeerListener() net.Listener {
|
||||||
|
return c.l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
||||||
|
c.svr.SetInWorkConnCallback(cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) UpdateProxyConfigurer(proxyCfgs []v1.ProxyConfigurer) {
|
||||||
|
_ = c.svr.ReloadConf(proxyCfgs, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Run(ctx context.Context) error {
|
||||||
|
return c.svr.Run(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Close() {
|
||||||
|
c.l.Close()
|
||||||
|
c.svr.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipeConnector struct {
|
||||||
|
peerListener *utilnet.InternalListener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *pipeConnector) Open() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *pipeConnector) Connect() (net.Conn, error) {
|
||||||
|
c1, c2 := net.Pipe()
|
||||||
|
if err := pc.peerListener.PutConn(c1); err != nil {
|
||||||
|
c1.Close()
|
||||||
|
c2.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *pipeConnector) Close() error {
|
||||||
|
pc.peerListener.Close()
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -230,14 +229,8 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var workConn net.Conn
|
|
||||||
|
|
||||||
// try all connections from the pool
|
// try all connections from the pool
|
||||||
if strings.HasPrefix(pxy.GetLoginMsg().User, v1.SSHClientLoginUserPrefix) {
|
workConn, err := pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr())
|
||||||
workConn, err = pxy.getWorkConnFn()
|
|
||||||
} else {
|
|
||||||
workConn, err = pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr())
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,13 +18,10 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -32,7 +29,6 @@ import (
|
||||||
fmux "github.com/hashicorp/yamux"
|
fmux "github.com/hashicorp/yamux"
|
||||||
quic "github.com/quic-go/quic-go"
|
quic "github.com/quic-go/quic-go"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/fatedier/frp/assets"
|
"github.com/fatedier/frp/assets"
|
||||||
"github.com/fatedier/frp/pkg/auth"
|
"github.com/fatedier/frp/pkg/auth"
|
||||||
|
@ -41,7 +37,7 @@ import (
|
||||||
"github.com/fatedier/frp/pkg/msg"
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
"github.com/fatedier/frp/pkg/nathole"
|
"github.com/fatedier/frp/pkg/nathole"
|
||||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||||
frpssh "github.com/fatedier/frp/pkg/ssh"
|
"github.com/fatedier/frp/pkg/ssh"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
@ -71,10 +67,6 @@ type Service struct {
|
||||||
// Accept connections from client
|
// Accept connections from client
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
|
|
||||||
// Accept connections using ssh
|
|
||||||
sshListener net.Listener
|
|
||||||
sshConfig *ssh.ServerConfig
|
|
||||||
|
|
||||||
// Accept connections using kcp
|
// Accept connections using kcp
|
||||||
kcpListener net.Listener
|
kcpListener net.Listener
|
||||||
|
|
||||||
|
@ -87,6 +79,8 @@ type Service struct {
|
||||||
// Accept frp tls connections
|
// Accept frp tls connections
|
||||||
tlsListener net.Listener
|
tlsListener net.Listener
|
||||||
|
|
||||||
|
virtualListener *utilnet.InternalListener
|
||||||
|
|
||||||
// Manage all controllers
|
// Manage all controllers
|
||||||
ctlManager *ControlManager
|
ctlManager *ControlManager
|
||||||
|
|
||||||
|
@ -102,6 +96,8 @@ type Service struct {
|
||||||
// All resource managers and controllers
|
// All resource managers and controllers
|
||||||
rc *controller.ResourceController
|
rc *controller.ResourceController
|
||||||
|
|
||||||
|
sshTunnelGateway *ssh.Gateway
|
||||||
|
|
||||||
// Verifies authentication based on selected method
|
// Verifies authentication based on selected method
|
||||||
authVerifier auth.Verifier
|
authVerifier auth.Verifier
|
||||||
|
|
||||||
|
@ -133,6 +129,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
|
||||||
TCPPortManager: ports.NewManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
|
TCPPortManager: ports.NewManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
|
||||||
UDPPortManager: ports.NewManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
|
UDPPortManager: ports.NewManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
|
||||||
},
|
},
|
||||||
|
virtualListener: utilnet.NewInternalListener(),
|
||||||
httpVhostRouter: vhost.NewRouters(),
|
httpVhostRouter: vhost.NewRouters(),
|
||||||
authVerifier: auth.NewAuthVerifier(cfg.Auth),
|
authVerifier: auth.NewAuthVerifier(cfg.Auth),
|
||||||
tlsConfig: tlsConfig,
|
tlsConfig: tlsConfig,
|
||||||
|
@ -208,67 +205,6 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
|
||||||
svr.listener = ln
|
svr.listener = ln
|
||||||
log.Info("frps tcp listen on %s", address)
|
log.Info("frps tcp listen on %s", address)
|
||||||
|
|
||||||
if cfg.SSHTunnelGateway.BindPort > 0 {
|
|
||||||
|
|
||||||
if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" {
|
|
||||||
cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadSSHPublicKeyFilesInDir(cfg.SSHTunnelGateway.PublicKeyFilesPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load ssh all public key files error: %v", err)
|
|
||||||
}
|
|
||||||
log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
svr.sshConfig = &ssh.ServerConfig{
|
|
||||||
NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false),
|
|
||||||
|
|
||||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
||||||
parsedAuthorizedKey, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("cannot find public key file")
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.Type() == parsedAuthorizedKey.Type() && reflect.DeepEqual(parsedAuthorizedKey, key) {
|
|
||||||
return &ssh.Permissions{
|
|
||||||
Extensions: map[string]string{},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unknown public key for %q", conn.User())
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var privateBytes []byte
|
|
||||||
if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" {
|
|
||||||
privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Failed to load private key")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath)
|
|
||||||
} else {
|
|
||||||
privateBytes, err = v1.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Failed to load private key")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.Info("auto gen private key file success")
|
|
||||||
}
|
|
||||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Failed to parse private key, error: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
svr.sshConfig.AddHostKey(private)
|
|
||||||
|
|
||||||
sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort))
|
|
||||||
svr.sshListener, err = net.Listen("tcp", sshAddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Failed to listen on %v, error: %v", sshAddr, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.Info("ssh server listening on %v", sshAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen for accepting connections from client using kcp protocol.
|
// Listen for accepting connections from client using kcp protocol.
|
||||||
if cfg.KCPBindPort > 0 {
|
if cfg.KCPBindPort > 0 {
|
||||||
address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort))
|
address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort))
|
||||||
|
@ -293,7 +229,17 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
|
||||||
err = fmt.Errorf("listen on quic udp address %s error: %v", address, err)
|
err = fmt.Errorf("listen on quic udp address %s error: %v", address, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Info("frps quic listen on quic %s", address)
|
log.Info("frps quic listen on %s", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.SSHTunnelGateway.BindPort > 0 {
|
||||||
|
sshGateway, err := ssh.NewGateway(cfg.SSHTunnelGateway, cfg.ProxyBindAddr, svr.virtualListener)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("create ssh gateway error: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
svr.sshTunnelGateway = sshGateway
|
||||||
|
log.Info("frps sshTunnelGateway listen on port %d", cfg.SSHTunnelGateway.BindPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen for accepting connections from client using websocket protocol.
|
// Listen for accepting connections from client using websocket protocol.
|
||||||
|
@ -396,23 +342,26 @@ func (svr *Service) Run(ctx context.Context) {
|
||||||
svr.ctx = ctx
|
svr.ctx = ctx
|
||||||
svr.cancel = cancel
|
svr.cancel = cancel
|
||||||
|
|
||||||
if svr.sshListener != nil {
|
go svr.HandleListener(svr.virtualListener, true)
|
||||||
go svr.HandleSSHListener(svr.sshListener)
|
|
||||||
}
|
|
||||||
|
|
||||||
if svr.kcpListener != nil {
|
if svr.kcpListener != nil {
|
||||||
go svr.HandleListener(svr.kcpListener)
|
go svr.HandleListener(svr.kcpListener, false)
|
||||||
}
|
}
|
||||||
if svr.quicListener != nil {
|
if svr.quicListener != nil {
|
||||||
go svr.HandleQUICListener(svr.quicListener)
|
go svr.HandleQUICListener(svr.quicListener)
|
||||||
}
|
}
|
||||||
go svr.HandleListener(svr.websocketListener)
|
go svr.HandleListener(svr.websocketListener, false)
|
||||||
go svr.HandleListener(svr.tlsListener)
|
go svr.HandleListener(svr.tlsListener, false)
|
||||||
|
|
||||||
if svr.rc.NatHoleController != nil {
|
if svr.rc.NatHoleController != nil {
|
||||||
go svr.rc.NatHoleController.CleanWorker(svr.ctx)
|
go svr.rc.NatHoleController.CleanWorker(svr.ctx)
|
||||||
}
|
}
|
||||||
svr.HandleListener(svr.listener)
|
|
||||||
|
if svr.sshTunnelGateway != nil {
|
||||||
|
go svr.sshTunnelGateway.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
svr.HandleListener(svr.listener, false)
|
||||||
|
|
||||||
<-svr.ctx.Done()
|
<-svr.ctx.Done()
|
||||||
// service context may not be canceled by svr.Close(), we should call it here to release resources
|
// service context may not be canceled by svr.Close(), we should call it here to release resources
|
||||||
|
@ -422,10 +371,6 @@ func (svr *Service) Run(ctx context.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svr *Service) Close() error {
|
func (svr *Service) Close() error {
|
||||||
if svr.sshListener != nil {
|
|
||||||
svr.sshListener.Close()
|
|
||||||
svr.sshListener = nil
|
|
||||||
}
|
|
||||||
if svr.kcpListener != nil {
|
if svr.kcpListener != nil {
|
||||||
svr.kcpListener.Close()
|
svr.kcpListener.Close()
|
||||||
svr.kcpListener = nil
|
svr.kcpListener = nil
|
||||||
|
@ -516,7 +461,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svr *Service) HandleListener(l net.Listener) {
|
func (svr *Service) HandleListener(l net.Listener, internal bool) {
|
||||||
// Listen for incoming connections from client.
|
// Listen for incoming connections from client.
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
|
@ -532,8 +477,9 @@ func (svr *Service) HandleListener(l net.Listener) {
|
||||||
|
|
||||||
log.Trace("start check TLS connection...")
|
log.Trace("start check TLS connection...")
|
||||||
originConn := c
|
originConn := c
|
||||||
|
forceTLS := svr.cfg.Transport.TLS.Force && !internal
|
||||||
var isTLS, custom bool
|
var isTLS, custom bool
|
||||||
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.Transport.TLS.Force, connReadTimeout)
|
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, forceTLS, connReadTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
|
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
|
||||||
originConn.Close()
|
originConn.Close()
|
||||||
|
@ -543,7 +489,7 @@ func (svr *Service) HandleListener(l net.Listener) {
|
||||||
|
|
||||||
// Start a new goroutine to handle connection.
|
// Start a new goroutine to handle connection.
|
||||||
go func(ctx context.Context, frpConn net.Conn) {
|
go func(ctx context.Context, frpConn net.Conn) {
|
||||||
if lo.FromPtr(svr.cfg.Transport.TCPMux) {
|
if lo.FromPtr(svr.cfg.Transport.TCPMux) && !internal {
|
||||||
fmuxCfg := fmux.DefaultConfig()
|
fmuxCfg := fmux.DefaultConfig()
|
||||||
fmuxCfg.KeepAliveInterval = time.Duration(svr.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
fmuxCfg.KeepAliveInterval = time.Duration(svr.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
||||||
fmuxCfg.LogOutput = io.Discard
|
fmuxCfg.LogOutput = io.Discard
|
||||||
|
@ -571,52 +517,6 @@ func (svr *Service) HandleListener(l net.Listener) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svr *Service) HandleSSHListener(listener net.Listener) {
|
|
||||||
for {
|
|
||||||
tcpConn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("failed to accept incoming ssh connection (%s)", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Info("new tcp conn connected: %v", tcpConn.RemoteAddr().String())
|
|
||||||
|
|
||||||
pxyPayloadCh := make(chan v1.ProxyConfigurer)
|
|
||||||
replyCh := make(chan interface{})
|
|
||||||
|
|
||||||
ss, err := frpssh.NewSSHService(tcpConn, svr.sshConfig, pxyPayloadCh, replyCh)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("new ssh service error: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ss.Run()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
pxyCfg := <-pxyPayloadCh
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// TODO fill client common config and login msg
|
|
||||||
vs, err := frpssh.NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg,
|
|
||||||
msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()},
|
|
||||||
svr.rc, pxyCfg, ss, replyCh)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("new virtual service error: %v", err)
|
|
||||||
ss.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = vs.Run(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("proxy run error: %v", err)
|
|
||||||
vs.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (svr *Service) HandleQUICListener(l *quic.Listener) {
|
func (svr *Service) HandleQUICListener(l *quic.Listener) {
|
||||||
// Listen for incoming connections from client.
|
// Listen for incoming connections from client.
|
||||||
for {
|
for {
|
||||||
|
|
Loading…
Reference in New Issue