mirror of https://github.com/fatedier/frp
fatedier
1 year ago
34 changed files with 1036 additions and 1255 deletions
@ -0,0 +1,223 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto/tls" |
||||
"io" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
libdial "github.com/fatedier/golib/net/dial" |
||||
fmux "github.com/hashicorp/yamux" |
||||
quic "github.com/quic-go/quic-go" |
||||
"github.com/samber/lo" |
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1" |
||||
"github.com/fatedier/frp/pkg/transport" |
||||
utilnet "github.com/fatedier/frp/pkg/util/net" |
||||
"github.com/fatedier/frp/pkg/util/xlog" |
||||
) |
||||
|
||||
// Connector is a interface for establishing connections to the server.
|
||||
type Connector interface { |
||||
Open() error |
||||
Connect() (net.Conn, error) |
||||
Close() error |
||||
} |
||||
|
||||
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
|
||||
type defaultConnectorImpl struct { |
||||
ctx context.Context |
||||
cfg *v1.ClientCommonConfig |
||||
|
||||
muxSession *fmux.Session |
||||
quicConn quic.Connection |
||||
} |
||||
|
||||
func NewConnector(ctx context.Context, cfg *v1.ClientCommonConfig) Connector { |
||||
return &defaultConnectorImpl{ |
||||
ctx: ctx, |
||||
cfg: cfg, |
||||
} |
||||
} |
||||
|
||||
// Open opens a underlying connection to the server.
|
||||
// The underlying connection is either a TCP connection or a QUIC connection.
|
||||
// After the underlying connection is established, you can call Connect() to get a stream.
|
||||
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
|
||||
func (c *defaultConnectorImpl) Open() error { |
||||
xl := xlog.FromContextSafe(c.ctx) |
||||
|
||||
// special for quic
|
||||
if strings.EqualFold(c.cfg.Transport.Protocol, "quic") { |
||||
var tlsConfig *tls.Config |
||||
var err error |
||||
sn := c.cfg.Transport.TLS.ServerName |
||||
if sn == "" { |
||||
sn = c.cfg.ServerAddr |
||||
} |
||||
if lo.FromPtr(c.cfg.Transport.TLS.Enable) { |
||||
tlsConfig, err = transport.NewClientTLSConfig( |
||||
c.cfg.Transport.TLS.CertFile, |
||||
c.cfg.Transport.TLS.KeyFile, |
||||
c.cfg.Transport.TLS.TrustedCaFile, |
||||
sn) |
||||
} else { |
||||
tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn) |
||||
} |
||||
if err != nil { |
||||
xl.Warn("fail to build tls configuration, err: %v", err) |
||||
return err |
||||
} |
||||
tlsConfig.NextProtos = []string{"frp"} |
||||
|
||||
conn, err := quic.DialAddr( |
||||
c.ctx, |
||||
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)), |
||||
tlsConfig, &quic.Config{ |
||||
MaxIdleTimeout: time.Duration(c.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second, |
||||
MaxIncomingStreams: int64(c.cfg.Transport.QUIC.MaxIncomingStreams), |
||||
KeepAlivePeriod: time.Duration(c.cfg.Transport.QUIC.KeepalivePeriod) * time.Second, |
||||
}) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
c.quicConn = conn |
||||
return nil |
||||
} |
||||
|
||||
if !lo.FromPtr(c.cfg.Transport.TCPMux) { |
||||
return nil |
||||
} |
||||
|
||||
conn, err := c.realConnect() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
fmuxCfg := fmux.DefaultConfig() |
||||
fmuxCfg.KeepAliveInterval = time.Duration(c.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second |
||||
fmuxCfg.LogOutput = io.Discard |
||||
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024 |
||||
session, err := fmux.Client(conn, fmuxCfg) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
c.muxSession = session |
||||
return nil |
||||
} |
||||
|
||||
// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
|
||||
func (c *defaultConnectorImpl) Connect() (net.Conn, error) { |
||||
if c.quicConn != nil { |
||||
stream, err := c.quicConn.OpenStreamSync(context.Background()) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return utilnet.QuicStreamToNetConn(stream, c.quicConn), nil |
||||
} else if c.muxSession != nil { |
||||
stream, err := c.muxSession.OpenStream() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return stream, nil |
||||
} |
||||
|
||||
return c.realConnect() |
||||
} |
||||
|
||||
func (c *defaultConnectorImpl) realConnect() (net.Conn, error) { |
||||
xl := xlog.FromContextSafe(c.ctx) |
||||
var tlsConfig *tls.Config |
||||
var err error |
||||
tlsEnable := lo.FromPtr(c.cfg.Transport.TLS.Enable) |
||||
if c.cfg.Transport.Protocol == "wss" { |
||||
tlsEnable = true |
||||
} |
||||
if tlsEnable { |
||||
sn := c.cfg.Transport.TLS.ServerName |
||||
if sn == "" { |
||||
sn = c.cfg.ServerAddr |
||||
} |
||||
|
||||
tlsConfig, err = transport.NewClientTLSConfig( |
||||
c.cfg.Transport.TLS.CertFile, |
||||
c.cfg.Transport.TLS.KeyFile, |
||||
c.cfg.Transport.TLS.TrustedCaFile, |
||||
sn) |
||||
if err != nil { |
||||
xl.Warn("fail to build tls configuration, err: %v", err) |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
proxyType, addr, auth, err := libdial.ParseProxyURL(c.cfg.Transport.ProxyURL) |
||||
if err != nil { |
||||
xl.Error("fail to parse proxy url") |
||||
return nil, err |
||||
} |
||||
dialOptions := []libdial.DialOption{} |
||||
protocol := c.cfg.Transport.Protocol |
||||
switch protocol { |
||||
case "websocket": |
||||
protocol = "tcp" |
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")})) |
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ |
||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)), |
||||
})) |
||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) |
||||
case "wss": |
||||
protocol = "tcp" |
||||
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig)) |
||||
// Make sure that if it is wss, the websocket hook is executed after the tls hook.
|
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110})) |
||||
default: |
||||
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ |
||||
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)), |
||||
})) |
||||
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) |
||||
} |
||||
|
||||
if c.cfg.Transport.ConnectServerLocalIP != "" { |
||||
dialOptions = append(dialOptions, libdial.WithLocalAddr(c.cfg.Transport.ConnectServerLocalIP)) |
||||
} |
||||
dialOptions = append(dialOptions, |
||||
libdial.WithProtocol(protocol), |
||||
libdial.WithTimeout(time.Duration(c.cfg.Transport.DialServerTimeout)*time.Second), |
||||
libdial.WithKeepAlive(time.Duration(c.cfg.Transport.DialServerKeepAlive)*time.Second), |
||||
libdial.WithProxy(proxyType, addr), |
||||
libdial.WithProxyAuth(auth), |
||||
) |
||||
conn, err := libdial.DialContext( |
||||
c.ctx, |
||||
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)), |
||||
dialOptions..., |
||||
) |
||||
return conn, err |
||||
} |
||||
|
||||
func (c *defaultConnectorImpl) Close() error { |
||||
if c.quicConn != nil { |
||||
_ = c.quicConn.CloseWithError(0, "") |
||||
} |
||||
if c.muxSession != nil { |
||||
_ = c.muxSession.Close() |
||||
} |
||||
return nil |
||||
} |
@ -1,110 +0,0 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"strconv" |
||||
|
||||
"github.com/spf13/cobra" |
||||
|
||||
"github.com/fatedier/frp/pkg/config/types" |
||||
v1 "github.com/fatedier/frp/pkg/config/v1" |
||||
) |
||||
|
||||
type PortsRangeSliceFlag struct { |
||||
V *[]types.PortsRange |
||||
} |
||||
|
||||
func (f *PortsRangeSliceFlag) String() string { |
||||
if f.V == nil { |
||||
return "" |
||||
} |
||||
return types.PortsRangeSlice(*f.V).String() |
||||
} |
||||
|
||||
func (f *PortsRangeSliceFlag) Set(s string) error { |
||||
slice, err := types.NewPortsRangeSliceFromString(s) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
*f.V = slice |
||||
return nil |
||||
} |
||||
|
||||
func (f *PortsRangeSliceFlag) Type() string { |
||||
return "string" |
||||
} |
||||
|
||||
type BoolFuncFlag struct { |
||||
TrueFunc func() |
||||
FalseFunc func() |
||||
|
||||
v bool |
||||
} |
||||
|
||||
func (f *BoolFuncFlag) String() string { |
||||
return strconv.FormatBool(f.v) |
||||
} |
||||
|
||||
func (f *BoolFuncFlag) Set(s string) error { |
||||
f.v = strconv.FormatBool(f.v) == "true" |
||||
|
||||
if !f.v { |
||||
if f.FalseFunc != nil { |
||||
f.FalseFunc() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
if f.TrueFunc != nil { |
||||
f.TrueFunc() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (f *BoolFuncFlag) Type() string { |
||||
return "bool" |
||||
} |
||||
|
||||
func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) { |
||||
cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address") |
||||
cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port") |
||||
cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port") |
||||
cmd.PersistentFlags().StringVarP(&c.ProxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address") |
||||
cmd.PersistentFlags().IntVarP(&c.VhostHTTPPort, "vhost_http_port", "", 0, "vhost http port") |
||||
cmd.PersistentFlags().IntVarP(&c.VhostHTTPSPort, "vhost_https_port", "", 0, "vhost https port") |
||||
cmd.PersistentFlags().Int64VarP(&c.VhostHTTPTimeout, "vhost_http_timeout", "", 60, "vhost http response header timeout") |
||||
cmd.PersistentFlags().StringVarP(&c.WebServer.Addr, "dashboard_addr", "", "0.0.0.0", "dashboard address") |
||||
cmd.PersistentFlags().IntVarP(&c.WebServer.Port, "dashboard_port", "", 0, "dashboard port") |
||||
cmd.PersistentFlags().StringVarP(&c.WebServer.User, "dashboard_user", "", "admin", "dashboard user") |
||||
cmd.PersistentFlags().StringVarP(&c.WebServer.Password, "dashboard_pwd", "", "admin", "dashboard password") |
||||
cmd.PersistentFlags().BoolVarP(&c.EnablePrometheus, "enable_prometheus", "", false, "enable prometheus dashboard") |
||||
cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "log file") |
||||
cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level") |
||||
cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log max days") |
||||
cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console") |
||||
cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token") |
||||
cmd.PersistentFlags().StringVarP(&c.SubDomainHost, "subdomain_host", "", "", "subdomain host") |
||||
cmd.PersistentFlags().VarP(&PortsRangeSliceFlag{V: &c.AllowPorts}, "allow_ports", "", "allow ports") |
||||
cmd.PersistentFlags().Int64VarP(&c.MaxPortsPerClient, "max_ports_per_client", "", 0, "max ports per client") |
||||
cmd.PersistentFlags().BoolVarP(&c.Transport.TLS.Force, "tls_only", "", false, "frps tls only") |
||||
|
||||
webServerTLS := v1.TLSConfig{} |
||||
cmd.PersistentFlags().StringVarP(&webServerTLS.CertFile, "dashboard_tls_cert_file", "", "", "dashboard tls cert file") |
||||
cmd.PersistentFlags().StringVarP(&webServerTLS.KeyFile, "dashboard_tls_key_file", "", "", "dashboard tls key file") |
||||
cmd.PersistentFlags().VarP(&BoolFuncFlag{ |
||||
TrueFunc: func() { c.WebServer.TLS = &webServerTLS }, |
||||
}, "dashboard_tls_mode", "", "if enable dashboard tls mode") |
||||
} |
@ -1,72 +0,0 @@
|
||||
package v1 |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"crypto/rsa" |
||||
"crypto/x509" |
||||
"encoding/pem" |
||||
"errors" |
||||
"os" |
||||
"path/filepath" |
||||
|
||||
"golang.org/x/crypto/ssh" |
||||
) |
||||
|
||||
const ( |
||||
// custom define
|
||||
SSHClientLoginUserPrefix = "_frpc_ssh_client_" |
||||
) |
||||
|
||||
// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format
|
||||
func GeneratePrivateKey() ([]byte, error) { |
||||
privateKey, err := generatePrivateKey() |
||||
if err != nil { |
||||
return nil, errors.New("gen private key error") |
||||
} |
||||
|
||||
privBlock := pem.Block{ |
||||
Type: "RSA PRIVATE KEY", |
||||
Headers: nil, |
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey), |
||||
} |
||||
|
||||
return pem.EncodeToMemory(&privBlock), nil |
||||
} |
||||
|
||||
// generatePrivateKey creates a RSA Private Key of specified byte size
|
||||
func generatePrivateKey() (*rsa.PrivateKey, error) { |
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
err = privateKey.Validate() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return privateKey, nil |
||||
} |
||||
|
||||
func LoadSSHPublicKeyFilesInDir(dirPath string) (map[string]ssh.PublicKey, error) { |
||||
fileMap := make(map[string]ssh.PublicKey) |
||||
files, err := os.ReadDir(dirPath) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
for _, file := range files { |
||||
filePath := filepath.Join(dirPath, file.Name()) |
||||
content, err := os.ReadFile(filePath) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(content) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
fileMap[ssh.FingerprintSHA256(parsedAuthorizedKey)] = parsedAuthorizedKey |
||||
} |
||||
|
||||
return fileMap, nil |
||||
} |
@ -0,0 +1,149 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"golang.org/x/crypto/ssh" |
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1" |
||||
"github.com/fatedier/frp/pkg/transport" |
||||
"github.com/fatedier/frp/pkg/util/log" |
||||
utilnet "github.com/fatedier/frp/pkg/util/net" |
||||
) |
||||
|
||||
type Gateway struct { |
||||
bindPort int |
||||
ln net.Listener |
||||
|
||||
serverPeerListener *utilnet.InternalListener |
||||
|
||||
sshConfig *ssh.ServerConfig |
||||
} |
||||
|
||||
func NewGateway( |
||||
cfg v1.SSHTunnelGateway, bindAddr string, |
||||
serverPeerListener *utilnet.InternalListener, |
||||
) (*Gateway, error) { |
||||
sshConfig := &ssh.ServerConfig{} |
||||
|
||||
// privateKey
|
||||
var ( |
||||
privateKeyBytes []byte |
||||
err error |
||||
) |
||||
if cfg.PrivateKeyFile != "" { |
||||
privateKeyBytes, err = os.ReadFile(cfg.PrivateKeyFile) |
||||
} else { |
||||
if cfg.AutoGenPrivateKeyPath != "" { |
||||
privateKeyBytes, _ = os.ReadFile(cfg.AutoGenPrivateKeyPath) |
||||
} |
||||
if len(privateKeyBytes) == 0 { |
||||
privateKeyBytes, err = transport.NewRandomPrivateKey() |
||||
if err == nil && cfg.AutoGenPrivateKeyPath != "" { |
||||
err = os.WriteFile(cfg.AutoGenPrivateKeyPath, privateKeyBytes, 0o600) |
||||
} |
||||
} |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyBytes) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
sshConfig.AddHostKey(privateKey) |
||||
|
||||
sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
||||
if cfg.AuthorizedKeysFile == "" { |
||||
return &ssh.Permissions{ |
||||
Extensions: map[string]string{ |
||||
"user": "", |
||||
}, |
||||
}, nil |
||||
} |
||||
|
||||
authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("internal error") |
||||
} |
||||
|
||||
user, ok := authorizedKeysMap[string(key.Marshal())] |
||||
if !ok { |
||||
return nil, fmt.Errorf("unknown public key for remoteAddr %q", conn.RemoteAddr()) |
||||
} |
||||
return &ssh.Permissions{ |
||||
Extensions: map[string]string{ |
||||
"user": user, |
||||
}, |
||||
}, nil |
||||
} |
||||
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(cfg.BindPort))) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &Gateway{ |
||||
bindPort: cfg.BindPort, |
||||
ln: ln, |
||||
serverPeerListener: serverPeerListener, |
||||
sshConfig: sshConfig, |
||||
}, nil |
||||
} |
||||
|
||||
func (g *Gateway) Run() { |
||||
for { |
||||
conn, err := g.ln.Accept() |
||||
if err != nil { |
||||
return |
||||
} |
||||
go g.handleConn(conn) |
||||
} |
||||
} |
||||
|
||||
func (g *Gateway) handleConn(conn net.Conn) { |
||||
defer conn.Close() |
||||
|
||||
ts, err := NewTunnelServer(conn, g.sshConfig, g.serverPeerListener) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if err := ts.Run(); err != nil { |
||||
log.Error("ssh tunnel server run error: %v", err) |
||||
} |
||||
} |
||||
|
||||
func loadAuthorizedKeysFromFile(path string) (map[string]string, error) { |
||||
authorizedKeysMap := make(map[string]string) // value is username
|
||||
authorizedKeysBytes, err := os.ReadFile(path) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
for len(authorizedKeysBytes) > 0 { |
||||
pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
authorizedKeysMap[string(pubKey.Marshal())] = strings.TrimSpace(comment) |
||||
authorizedKeysBytes = rest |
||||
} |
||||
return authorizedKeysMap, nil |
||||
} |
@ -0,0 +1,279 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"net" |
||||
"strings" |
||||
"time" |
||||
|
||||
libio "github.com/fatedier/golib/io" |
||||
"github.com/samber/lo" |
||||
"github.com/spf13/cobra" |
||||
"golang.org/x/crypto/ssh" |
||||
|
||||
"github.com/fatedier/frp/pkg/config" |
||||
v1 "github.com/fatedier/frp/pkg/config/v1" |
||||
"github.com/fatedier/frp/pkg/msg" |
||||
utilnet "github.com/fatedier/frp/pkg/util/net" |
||||
"github.com/fatedier/frp/pkg/util/util" |
||||
"github.com/fatedier/frp/pkg/util/xlog" |
||||
"github.com/fatedier/frp/pkg/virtual" |
||||
) |
||||
|
||||
const ( |
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
ChannelTypeServerOpenChannel = "forwarded-tcpip" |
||||
RequestTypeForward = "tcpip-forward" |
||||
) |
||||
|
||||
type tcpipForward struct { |
||||
Host string |
||||
Port uint32 |
||||
} |
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
type forwardedTCPPayload struct { |
||||
Addr string |
||||
Port uint32 |
||||
|
||||
// can be default empty value but do not delete it
|
||||
// because ssh protocol shoule be reserved
|
||||
OriginAddr string |
||||
OriginPort uint32 |
||||
} |
||||
|
||||
type TunnelServer struct { |
||||
underlyingConn net.Conn |
||||
sshConn *ssh.ServerConn |
||||
sc *ssh.ServerConfig |
||||
|
||||
vc *virtual.Client |
||||
serverPeerListener *utilnet.InternalListener |
||||
doneCh chan struct{} |
||||
} |
||||
|
||||
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, serverPeerListener *utilnet.InternalListener) (*TunnelServer, error) { |
||||
s := &TunnelServer{ |
||||
underlyingConn: conn, |
||||
sc: sc, |
||||
serverPeerListener: serverPeerListener, |
||||
doneCh: make(chan struct{}), |
||||
} |
||||
return s, nil |
||||
} |
||||
|
||||
func (s *TunnelServer) Run() error { |
||||
sshConn, channels, requests, err := ssh.NewServerConn(s.underlyingConn, s.sc) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
s.sshConn = sshConn |
||||
|
||||
addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User) |
||||
pc.Complete(clientCfg.User) |
||||
|
||||
s.vc = virtual.NewClient(clientCfg) |
||||
// join workConn and ssh channel
|
||||
s.vc.SetInWorkConnCallback(func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool { |
||||
c, err := s.openConn(addr) |
||||
if err != nil { |
||||
return false |
||||
} |
||||
libio.Join(c, workConn) |
||||
return false |
||||
}) |
||||
// transfer connection from virtual client to server peer listener
|
||||
go func() { |
||||
l := s.vc.PeerListener() |
||||
for { |
||||
conn, err := l.Accept() |
||||
if err != nil { |
||||
return |
||||
} |
||||
_ = s.serverPeerListener.PutConn(conn) |
||||
} |
||||
}() |
||||
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100}) |
||||
ctx := xlog.NewContext(context.Background(), xl) |
||||
go func() { |
||||
_ = s.vc.Run(ctx) |
||||
}() |
||||
|
||||
s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc}) |
||||
|
||||
_ = sshConn.Wait() |
||||
_ = sshConn.Close() |
||||
s.vc.Close() |
||||
close(s.doneCh) |
||||
return nil |
||||
} |
||||
|
||||
func (s *TunnelServer) waitForwardAddrAndExtraPayload( |
||||
channels <-chan ssh.NewChannel, |
||||
requests <-chan *ssh.Request, |
||||
timeout time.Duration, |
||||
) (*tcpipForward, string, error) { |
||||
addrCh := make(chan *tcpipForward, 1) |
||||
extraPayloadCh := make(chan string, 1) |
||||
|
||||
// get forward address
|
||||
go func() { |
||||
addrGot := false |
||||
for req := range requests { |
||||
switch req.Type { |
||||
case RequestTypeForward: |
||||
if !addrGot { |
||||
payload := tcpipForward{} |
||||
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { |
||||
return |
||||
} |
||||
addrGot = true |
||||
addrCh <- &payload |
||||
} |
||||
default: |
||||
if req.WantReply { |
||||
_ = req.Reply(true, nil) |
||||
} |
||||
} |
||||
} |
||||
}() |
||||
|
||||
// get extra payload
|
||||
go func() { |
||||
for newChannel := range channels { |
||||
// extraPayload will send to extraPayloadCh
|
||||
go s.handleNewChannel(newChannel, extraPayloadCh) |
||||
} |
||||
}() |
||||
|
||||
var ( |
||||
addr *tcpipForward |
||||
extraPayload string |
||||
) |
||||
|
||||
timer := time.NewTimer(timeout) |
||||
defer timer.Stop() |
||||
for { |
||||
select { |
||||
case v := <-addrCh: |
||||
addr = v |
||||
case extra := <-extraPayloadCh: |
||||
extraPayload = extra |
||||
case <-timer.C: |
||||
return nil, "", fmt.Errorf("get addr and extra payload timeout") |
||||
} |
||||
if addr != nil && extraPayload != "" { |
||||
break |
||||
} |
||||
} |
||||
return addr, extraPayload, nil |
||||
} |
||||
|
||||
func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) { |
||||
cmd := &cobra.Command{} |
||||
args := strings.Split(extraPayload, " ") |
||||
if len(args) < 1 { |
||||
return nil, nil, fmt.Errorf("invalid extra payload") |
||||
} |
||||
proxyType := strings.TrimSpace(args[0]) |
||||
supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"} |
||||
if !lo.Contains(supportTypes, proxyType) { |
||||
return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes) |
||||
} |
||||
pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType)) |
||||
if pc == nil { |
||||
return nil, nil, fmt.Errorf("new proxy configurer error") |
||||
} |
||||
config.RegisterProxyFlags(cmd, pc) |
||||
|
||||
clientCfg := v1.ClientCommonConfig{} |
||||
config.RegisterClientCommonConfigFlags(cmd, &clientCfg) |
||||
|
||||
if err := cmd.ParseFlags(args); err != nil { |
||||
return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err) |
||||
} |
||||
return &clientCfg, pc, nil |
||||
} |
||||
|
||||
func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) { |
||||
ch, reqs, err := channel.Accept() |
||||
if err != nil { |
||||
return |
||||
} |
||||
go s.keepAlive(ch) |
||||
|
||||
for req := range reqs { |
||||
if req.Type != "exec" { |
||||
continue |
||||
} |
||||
if len(req.Payload) <= 4 { |
||||
continue |
||||
} |
||||
end := 4 + binary.BigEndian.Uint32(req.Payload[:4]) |
||||
if len(req.Payload) < int(end) { |
||||
continue |
||||
} |
||||
extraPayload := string(req.Payload[4:end]) |
||||
select { |
||||
case extraPayloadCh <- extraPayload: |
||||
default: |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (s *TunnelServer) keepAlive(ch ssh.Channel) { |
||||
tk := time.NewTicker(time.Second * 30) |
||||
defer tk.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-tk.C: |
||||
_, err := ch.SendRequest("heartbeat", false, nil) |
||||
if err != nil { |
||||
return |
||||
} |
||||
case <-s.doneCh: |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) { |
||||
payload := forwardedTCPPayload{ |
||||
Addr: addr.Host, |
||||
Port: addr.Port, |
||||
} |
||||
channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload)) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("open ssh channel error: %v", err) |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
|
||||
conn := utilnet.WrapReadWriteCloserToConn(channel, s.underlyingConn) |
||||
return conn, nil |
||||
} |
@ -1,497 +0,0 @@
|
||||
package ssh |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"errors" |
||||
"flag" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
gerror "github.com/fatedier/golib/errors" |
||||
"golang.org/x/crypto/ssh" |
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1" |
||||
"github.com/fatedier/frp/pkg/util/log" |
||||
) |
||||
|
||||
const ( |
||||
// ssh protocol define
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
ChannelTypeServerOpenChannel = "forwarded-tcpip" |
||||
RequestTypeForward = "tcpip-forward" |
||||
|
||||
// golang ssh package define.
|
||||
// https://pkg.go.dev/golang.org/x/crypto/ssh
|
||||
RequestTypeHeartbeat = "keepalive@openssh.com" |
||||
) |
||||
|
||||
// 当 proxy 失败会返回该错误
|
||||
type VProxyError struct{} |
||||
|
||||
// ssh protocol define
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
|
||||
// parse ssh client cmds input
|
||||
type forwardedTCPPayload struct { |
||||
Addr string |
||||
Port uint32 |
||||
|
||||
// can be default empty value but do not delete it
|
||||
// because ssh protocol shoule be reserved
|
||||
OriginAddr string |
||||
OriginPort uint32 |
||||
} |
||||
|
||||
// custom define
|
||||
// parse ssh client cmds input
|
||||
type CmdPayload struct { |
||||
Address string |
||||
Port uint32 |
||||
} |
||||
|
||||
// custom define
|
||||
// with frp control cmds
|
||||
type ExtraPayload struct { |
||||
Type string |
||||
|
||||
// TODO port can be set by extra message and priority to ssh raw cmd
|
||||
Address string |
||||
Port uint32 |
||||
} |
||||
|
||||
type Service struct { |
||||
tcpConn net.Conn |
||||
cfg *ssh.ServerConfig |
||||
|
||||
sshConn *ssh.ServerConn |
||||
gChannel <-chan ssh.NewChannel |
||||
gReq <-chan *ssh.Request |
||||
|
||||
addrPayloadCh chan CmdPayload |
||||
extraPayloadCh chan ExtraPayload |
||||
|
||||
proxyPayloadCh chan v1.ProxyConfigurer |
||||
replyCh chan interface{} |
||||
|
||||
closeCh chan struct{} |
||||
exit int32 |
||||
} |
||||
|
||||
func NewSSHService( |
||||
tcpConn net.Conn, |
||||
cfg *ssh.ServerConfig, |
||||
proxyPayloadCh chan v1.ProxyConfigurer, |
||||
replyCh chan interface{}, |
||||
) (ss *Service, err error) { |
||||
ss = &Service{ |
||||
tcpConn: tcpConn, |
||||
cfg: cfg, |
||||
|
||||
addrPayloadCh: make(chan CmdPayload), |
||||
extraPayloadCh: make(chan ExtraPayload), |
||||
|
||||
proxyPayloadCh: proxyPayloadCh, |
||||
replyCh: replyCh, |
||||
|
||||
closeCh: make(chan struct{}), |
||||
exit: 0, |
||||
} |
||||
|
||||
ss.sshConn, ss.gChannel, ss.gReq, err = ssh.NewServerConn(tcpConn, cfg) |
||||
if err != nil { |
||||
log.Error("ssh handshake error: %v", err) |
||||
return nil, err |
||||
} |
||||
|
||||
log.Info("ssh connection success") |
||||
|
||||
return ss, nil |
||||
} |
||||
|
||||
func (ss *Service) Run() { |
||||
go ss.loopGenerateProxy() |
||||
go ss.loopParseCmdPayload() |
||||
go ss.loopParseExtraPayload() |
||||
go ss.loopReply() |
||||
} |
||||
|
||||
func (ss *Service) Exit() <-chan struct{} { |
||||
return ss.closeCh |
||||
} |
||||
|
||||
func (ss *Service) Close() { |
||||
if atomic.LoadInt32(&ss.exit) == 1 { |
||||
return |
||||
} |
||||
|
||||
select { |
||||
case <-ss.closeCh: |
||||
return |
||||
default: |
||||
} |
||||
|
||||
close(ss.closeCh) |
||||
close(ss.addrPayloadCh) |
||||
close(ss.extraPayloadCh) |
||||
|
||||
_ = ss.sshConn.Wait() |
||||
|
||||
ss.sshConn.Close() |
||||
ss.tcpConn.Close() |
||||
|
||||
atomic.StoreInt32(&ss.exit, 1) |
||||
|
||||
log.Info("ssh service close") |
||||
} |
||||
|
||||
func (ss *Service) loopParseCmdPayload() { |
||||
for { |
||||
select { |
||||
case req, ok := <-ss.gReq: |
||||
if !ok { |
||||
log.Info("global request is close") |
||||
ss.Close() |
||||
return |
||||
} |
||||
|
||||
switch req.Type { |
||||
case RequestTypeForward: |
||||
var addrPayload CmdPayload |
||||
if err := ssh.Unmarshal(req.Payload, &addrPayload); err != nil { |
||||
log.Error("ssh unmarshal error: %v", err) |
||||
return |
||||
} |
||||
_ = gerror.PanicToError(func() { |
||||
ss.addrPayloadCh <- addrPayload |
||||
}) |
||||
default: |
||||
if req.Type == RequestTypeHeartbeat { |
||||
log.Debug("ssh heartbeat data") |
||||
} else { |
||||
log.Info("default req, data: %v", req) |
||||
} |
||||
} |
||||
if req.WantReply { |
||||
err := req.Reply(true, nil) |
||||
if err != nil { |
||||
log.Error("reply to ssh client error: %v", err) |
||||
} |
||||
} |
||||
case <-ss.closeCh: |
||||
log.Info("loop parse cmd payload close") |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ss *Service) loopSendHeartbeat(ch ssh.Channel) { |
||||
tk := time.NewTicker(time.Second * 60) |
||||
defer tk.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-tk.C: |
||||
ok, err := ch.SendRequest("heartbeat", false, nil) |
||||
if err != nil { |
||||
log.Error("channel send req error: %v", err) |
||||
if err == io.EOF { |
||||
ss.Close() |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
log.Debug("heartbeat send success, ok: %v", ok) |
||||
case <-ss.closeCh: |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ss *Service) loopParseExtraPayload() { |
||||
log.Info("loop parse extra payload start") |
||||
|
||||
for newChannel := range ss.gChannel { |
||||
ch, req, err := newChannel.Accept() |
||||
if err != nil { |
||||
log.Error("channel accept error: %v", err) |
||||
return |
||||
} |
||||
|
||||
go ss.loopSendHeartbeat(ch) |
||||
|
||||
go func(req <-chan *ssh.Request) { |
||||
for r := range req { |
||||
if len(r.Payload) <= 4 { |
||||
log.Info("r.payload is less than 4") |
||||
continue |
||||
} |
||||
if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") { |
||||
log.Info("ssh protocol exchange data") |
||||
continue |
||||
} |
||||
|
||||
// [4byte data_len|data]
|
||||
end := 4 + binary.BigEndian.Uint32(r.Payload[:4]) |
||||
if end > uint32(len(r.Payload)) { |
||||
end = uint32(len(r.Payload)) |
||||
} |
||||
p := string(r.Payload[4:end]) |
||||
|
||||
msg, err := parseSSHExtraMessage(p) |
||||
if err != nil { |
||||
log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload) |
||||
continue |
||||
} |
||||
_ = gerror.PanicToError(func() { |
||||
ss.extraPayloadCh <- msg |
||||
}) |
||||
return |
||||
} |
||||
}(req) |
||||
} |
||||
} |
||||
|
||||
func (ss *Service) SSHConn() *ssh.ServerConn { |
||||
return ss.sshConn |
||||
} |
||||
|
||||
func (ss *Service) TCPConn() net.Conn { |
||||
return ss.tcpConn |
||||
} |
||||
|
||||
func (ss *Service) loopReply() { |
||||
for { |
||||
select { |
||||
case <-ss.closeCh: |
||||
log.Info("loop reply close") |
||||
return |
||||
case req := <-ss.replyCh: |
||||
switch req.(type) { |
||||
case *VProxyError: |
||||
log.Error("run frp proxy error, close ssh service") |
||||
ss.Close() |
||||
default: |
||||
// TODO
|
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (ss *Service) loopGenerateProxy() { |
||||
log.Info("loop generate proxy start") |
||||
|
||||
for { |
||||
if atomic.LoadInt32(&ss.exit) == 1 { |
||||
return |
||||
} |
||||
|
||||
wg := new(sync.WaitGroup) |
||||
wg.Add(2) |
||||
|
||||
var p1 CmdPayload |
||||
var p2 ExtraPayload |
||||
|
||||
go func() { |
||||
defer wg.Done() |
||||
for { |
||||
select { |
||||
case <-ss.closeCh: |
||||
return |
||||
case p1 = <-ss.addrPayloadCh: |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
go func() { |
||||
defer wg.Done() |
||||
for { |
||||
select { |
||||
case <-ss.closeCh: |
||||
return |
||||
case p2 = <-ss.extraPayloadCh: |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
wg.Wait() |
||||
|
||||
if atomic.LoadInt32(&ss.exit) == 1 { |
||||
return |
||||
} |
||||
|
||||
switch p2.Type { |
||||
case "http": |
||||
case "tcp": |
||||
ss.proxyPayloadCh <- &v1.TCPProxyConfig{ |
||||
ProxyBaseConfig: v1.ProxyBaseConfig{ |
||||
Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()), |
||||
Type: p2.Type, |
||||
|
||||
ProxyBackend: v1.ProxyBackend{ |
||||
LocalIP: p1.Address, |
||||
}, |
||||
}, |
||||
RemotePort: int(p1.Port), |
||||
} |
||||
default: |
||||
log.Warn("invalid frp proxy type: %v", p2.Type) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func parseSSHExtraMessage(s string) (p ExtraPayload, err error) { |
||||
sn := len(s) |
||||
|
||||
log.Info("parse ssh extra message: %v", s) |
||||
|
||||
ss := strings.Fields(s) |
||||
if len(ss) == 0 { |
||||
if sn != 0 { |
||||
ss = append(ss, s) |
||||
} else { |
||||
return p, fmt.Errorf("invalid ssh input, args: %v", ss) |
||||
} |
||||
} |
||||
|
||||
for i, v := range ss { |
||||
ss[i] = strings.TrimSpace(v) |
||||
} |
||||
|
||||
if ss[0] != "tcp" && ss[0] != "http" { |
||||
return p, fmt.Errorf("only support tcp/http now") |
||||
} |
||||
|
||||
switch ss[0] { |
||||
case "tcp": |
||||
tcpCmd, err := ParseTCPCommand(ss) |
||||
if err != nil { |
||||
return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) |
||||
} |
||||
|
||||
port, _ := strconv.Atoi(tcpCmd.Port) |
||||
|
||||
p = ExtraPayload{ |
||||
Type: "tcp", |
||||
Address: tcpCmd.Address, |
||||
Port: uint32(port), |
||||
} |
||||
case "http": |
||||
httpCmd, err := ParseHTTPCommand(ss) |
||||
if err != nil { |
||||
return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) |
||||
} |
||||
|
||||
_ = httpCmd |
||||
|
||||
p = ExtraPayload{ |
||||
Type: "http", |
||||
} |
||||
} |
||||
|
||||
return p, nil |
||||
} |
||||
|
||||
type HTTPCommand struct { |
||||
Domain string |
||||
BasicAuthUser string |
||||
BasicAuthPass string |
||||
} |
||||
|
||||
func ParseHTTPCommand(params []string) (*HTTPCommand, error) { |
||||
if len(params) < 2 { |
||||
return nil, errors.New("invalid HTTP command") |
||||
} |
||||
|
||||
var ( |
||||
basicAuth string |
||||
domainURL string |
||||
basicAuthUser string |
||||
basicAuthPass string |
||||
) |
||||
|
||||
fs := flag.NewFlagSet("http", flag.ContinueOnError) |
||||
fs.StringVar(&basicAuth, "basic-auth", "", "") |
||||
fs.StringVar(&domainURL, "domain", "", "") |
||||
|
||||
fs.SetOutput(&nullWriter{}) // Disables usage output
|
||||
|
||||
err := fs.Parse(params[2:]) |
||||
if err != nil { |
||||
if !errors.Is(err, flag.ErrHelp) { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
if basicAuth != "" { |
||||
authParts := strings.SplitN(basicAuth, ":", 2) |
||||
basicAuthUser = authParts[0] |
||||
if len(authParts) > 1 { |
||||
basicAuthPass = authParts[1] |
||||
} |
||||
} |
||||
|
||||
httpCmd := &HTTPCommand{ |
||||
Domain: domainURL, |
||||
BasicAuthUser: basicAuthUser, |
||||
BasicAuthPass: basicAuthPass, |
||||
} |
||||
return httpCmd, nil |
||||
} |
||||
|
||||
type TCPCommand struct { |
||||
Address string |
||||
Port string |
||||
} |
||||
|
||||
func ParseTCPCommand(params []string) (*TCPCommand, error) { |
||||
if len(params) == 0 || params[0] != "tcp" { |
||||
return nil, errors.New("invalid TCP command") |
||||
} |
||||
|
||||
if len(params) == 1 { |
||||
return &TCPCommand{}, nil |
||||
} |
||||
|
||||
var ( |
||||
address string |
||||
port string |
||||
) |
||||
|
||||
fs := flag.NewFlagSet("tcp", flag.ContinueOnError) |
||||
fs.StringVar(&address, "address", "", "The IP address to listen on") |
||||
fs.StringVar(&port, "port", "", "The port to listen on") |
||||
fs.SetOutput(&nullWriter{}) // Disables usage output
|
||||
|
||||
args := params[1:] |
||||
err := fs.Parse(args) |
||||
if err != nil { |
||||
if !errors.Is(err, flag.ErrHelp) { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
parsedAddr, err := net.ResolveIPAddr("ip", address) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if _, err := net.LookupPort("tcp", port); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
tcpCmd := &TCPCommand{ |
||||
Address: parsedAddr.String(), |
||||
Port: port, |
||||
} |
||||
return tcpCmd, nil |
||||
} |
||||
|
||||
type nullWriter struct{} |
||||
|
||||
func (w *nullWriter) Write(p []byte) (n int, err error) { return len(p), nil } |
@ -1,185 +0,0 @@
|
||||
package ssh |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"golang.org/x/crypto/ssh" |
||||
|
||||
"github.com/fatedier/frp/pkg/config" |
||||
v1 "github.com/fatedier/frp/pkg/config/v1" |
||||
"github.com/fatedier/frp/pkg/msg" |
||||
plugin "github.com/fatedier/frp/pkg/plugin/server" |
||||
"github.com/fatedier/frp/pkg/util/log" |
||||
frp_net "github.com/fatedier/frp/pkg/util/net" |
||||
"github.com/fatedier/frp/pkg/util/util" |
||||
"github.com/fatedier/frp/pkg/util/xlog" |
||||
"github.com/fatedier/frp/server/controller" |
||||
"github.com/fatedier/frp/server/proxy" |
||||
) |
||||
|
||||
// VirtualService is a client VirtualService run in frps
|
||||
type VirtualService struct { |
||||
clientCfg v1.ClientCommonConfig |
||||
pxyCfg v1.ProxyConfigurer |
||||
serverCfg v1.ServerConfig |
||||
|
||||
sshSvc *Service |
||||
|
||||
// uniq id got from frps, attach it in loginMsg
|
||||
runID string |
||||
loginMsg *msg.Login |
||||
|
||||
// All resource managers and controllers
|
||||
rc *controller.ResourceController |
||||
|
||||
exit uint32 // 0 means not exit
|
||||
// SSHService context
|
||||
ctx context.Context |
||||
// call cancel to stop SSHService
|
||||
cancel context.CancelFunc |
||||
|
||||
replyCh chan interface{} |
||||
pxy proxy.Proxy |
||||
} |
||||
|
||||
func NewVirtualService( |
||||
ctx context.Context, |
||||
clientCfg v1.ClientCommonConfig, |
||||
serverCfg v1.ServerConfig, |
||||
logMsg msg.Login, |
||||
rc *controller.ResourceController, |
||||
pxyCfg v1.ProxyConfigurer, |
||||
sshSvc *Service, |
||||
replyCh chan interface{}, |
||||
) (svr *VirtualService, err error) { |
||||
svr = &VirtualService{ |
||||
clientCfg: clientCfg, |
||||
serverCfg: serverCfg, |
||||
rc: rc, |
||||
|
||||
loginMsg: &logMsg, |
||||
|
||||
sshSvc: sshSvc, |
||||
pxyCfg: pxyCfg, |
||||
|
||||
ctx: ctx, |
||||
exit: 0, |
||||
|
||||
replyCh: replyCh, |
||||
} |
||||
|
||||
svr.runID, err = util.RandID() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
go svr.loopCheck() |
||||
|
||||
return |
||||
} |
||||
|
||||
func (svr *VirtualService) Run(ctx context.Context) (err error) { |
||||
ctx, cancel := context.WithCancel(ctx) |
||||
svr.ctx = xlog.NewContext(ctx, xlog.New()) |
||||
svr.cancel = cancel |
||||
|
||||
remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{ |
||||
ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name, |
||||
ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type, |
||||
RemotePort: svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort, |
||||
}) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
log.Info("run a reverse proxy on port: %v", remoteAddr) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (svr *VirtualService) Close() { |
||||
svr.GracefulClose(time.Duration(0)) |
||||
} |
||||
|
||||
func (svr *VirtualService) GracefulClose(d time.Duration) { |
||||
atomic.StoreUint32(&svr.exit, 1) |
||||
svr.pxy.Close() |
||||
|
||||
if svr.cancel != nil { |
||||
svr.cancel() |
||||
} |
||||
|
||||
svr.replyCh <- &VProxyError{} |
||||
} |
||||
|
||||
func (svr *VirtualService) loopCheck() { |
||||
<-svr.sshSvc.Exit() |
||||
svr.pxy.Close() |
||||
log.Info("virtual client service close") |
||||
} |
||||
|
||||
func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { |
||||
var pxyConf v1.ProxyConfigurer |
||||
pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, &svr.serverCfg) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
// User info
|
||||
userInfo := plugin.UserInfo{ |
||||
User: svr.loginMsg.User, |
||||
Metas: svr.loginMsg.Metas, |
||||
RunID: svr.runID, |
||||
} |
||||
|
||||
svr.pxy, err = proxy.NewProxy(svr.ctx, &proxy.Options{ |
||||
LoginMsg: svr.loginMsg, |
||||
UserInfo: userInfo, |
||||
Configurer: pxyConf, |
||||
ResourceController: svr.rc, |
||||
|
||||
GetWorkConnFn: svr.GetWorkConn, |
||||
PoolCount: 10, |
||||
|
||||
ServerCfg: &svr.serverCfg, |
||||
}) |
||||
if err != nil { |
||||
return remoteAddr, err |
||||
} |
||||
|
||||
remoteAddr, err = svr.pxy.Run() |
||||
if err != nil { |
||||
log.Warn("proxy run error: %v", err) |
||||
return |
||||
} |
||||
|
||||
defer func() { |
||||
if err != nil { |
||||
log.Warn("proxy close") |
||||
svr.pxy.Close() |
||||
} |
||||
}() |
||||
|
||||
return |
||||
} |
||||
|
||||
func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) { |
||||
// tell ssh client open a new stream for work
|
||||
payload := forwardedTCPPayload{ |
||||
Addr: svr.serverCfg.BindAddr, // TODO refine
|
||||
Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort), |
||||
} |
||||
|
||||
channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload)) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("open ssh channel error: %v", err) |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
|
||||
workConn = frp_net.WrapReadWriteCloserToConn(channel, svr.sshSvc.tcpConn) |
||||
return workConn, nil |
||||
} |
@ -0,0 +1,92 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package virtual |
||||
|
||||
import ( |
||||
"context" |
||||
"net" |
||||
|
||||
"github.com/fatedier/frp/client" |
||||
v1 "github.com/fatedier/frp/pkg/config/v1" |
||||
"github.com/fatedier/frp/pkg/msg" |
||||
utilnet "github.com/fatedier/frp/pkg/util/net" |
||||
) |
||||
|
||||
type Client struct { |
||||
l *utilnet.InternalListener |
||||
svr *client.Service |
||||
} |
||||
|
||||
func NewClient(cfg *v1.ClientCommonConfig) *Client { |
||||
cfg.Complete() |
||||
|
||||
ln := utilnet.NewInternalListener() |
||||
|
||||
svr := client.NewService(cfg, nil, nil, "") |
||||
svr.SetConnectorCreator(func(context.Context, *v1.ClientCommonConfig) client.Connector { |
||||
return &pipeConnector{ |
||||
peerListener: ln, |
||||
} |
||||
}) |
||||
|
||||
return &Client{ |
||||
l: ln, |
||||
svr: svr, |
||||
} |
||||
} |
||||
|
||||
func (c *Client) PeerListener() net.Listener { |
||||
return c.l |
||||
} |
||||
|
||||
func (c *Client) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { |
||||
c.svr.SetInWorkConnCallback(cb) |
||||
} |
||||
|
||||
func (c *Client) UpdateProxyConfigurer(proxyCfgs []v1.ProxyConfigurer) { |
||||
_ = c.svr.ReloadConf(proxyCfgs, nil) |
||||
} |
||||
|
||||
func (c *Client) Run(ctx context.Context) error { |
||||
return c.svr.Run(ctx) |
||||
} |
||||
|
||||
func (c *Client) Close() { |
||||
c.l.Close() |
||||
c.svr.Close() |
||||
} |
||||
|
||||
type pipeConnector struct { |
||||
peerListener *utilnet.InternalListener |
||||
} |
||||
|
||||
func (pc *pipeConnector) Open() error { |
||||
return nil |
||||
} |
||||
|
||||
func (pc *pipeConnector) Connect() (net.Conn, error) { |
||||
c1, c2 := net.Pipe() |
||||
if err := pc.peerListener.PutConn(c1); err != nil { |
||||
c1.Close() |
||||
c2.Close() |
||||
return nil, err |
||||
} |
||||
return c2, nil |
||||
} |
||||
|
||||
func (pc *pipeConnector) Close() error { |
||||
pc.peerListener.Close() |
||||
return nil |
||||
} |
Loading…
Reference in new issue