diff --git a/README.md b/README.md index 9ca3a4d5..c65d7fa2 100644 --- a/README.md +++ b/README.md @@ -461,6 +461,8 @@ Config `tls_enable = true` in the `[common]` section to `frpc.ini` to enable thi For port multiplexing, frp sends a first byte `0x17` to dial a TLS connection. +To enforce `frps` to only accept TLS connections - configure `tls_only = true` in the `[common]` section in `frps.ini`. + ### Hot-Reloading frpc configuration The `admin_addr` and `admin_port` fields are required for enabling HTTP API: diff --git a/models/config/server_common.go b/models/config/server_common.go index d4faf8c3..6ad2688d 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -131,6 +131,9 @@ type ServerCommonConf struct { // may proxy to. If this value is 0, no limit will be applied. By default, // this value is 0. MaxPortsPerClient int64 `json:"max_ports_per_client"` + // TlsOnly specifies whether to only accept TLS-encrypted connections. By + // default, the value is false. + TlsOnly bool `json:"tls_only"` // HeartBeatTimeout specifies the maximum time to wait for a heartbeat // before terminating the connection. It is not recommended to change this // value. By default, this value is 90. @@ -171,6 +174,7 @@ func GetDefaultServerConf() ServerCommonConf { AllowPorts: make(map[int]struct{}), MaxPoolCount: 5, MaxPortsPerClient: 0, + TlsOnly: false, HeartBeatTimeout: 90, UserConnTimeout: 10, Custom404Page: "", @@ -388,6 +392,12 @@ func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error cfg.HeartBeatTimeout = v } } + + if tmpStr, ok = conf.Get("common", "tls_only"); ok && tmpStr == "true" { + cfg.TlsOnly = true + } else { + cfg.TlsOnly = false + } return } diff --git a/server/service.go b/server/service.go index dd29bbc8..50ebe0fc 100644 --- a/server/service.go +++ b/server/service.go @@ -284,7 +284,7 @@ func (svr *Service) HandleListener(l net.Listener) { log.Trace("start check TLS connection...") originConn := c - c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, connReadTimeout) + c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TlsOnly, connReadTimeout) if err != nil { log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err) originConn.Close() diff --git a/tests/ci/tls_test.go b/tests/ci/tls_test.go index d2ad8013..c46e9fa7 100644 --- a/tests/ci/tls_test.go +++ b/tests/ci/tls_test.go @@ -186,3 +186,95 @@ func TestTLSOverWebsocket(t *testing.T) { assert.NoError(err) assert.Equal(consts.TEST_TCP_ECHO_STR, res) } + +const FRPS_TLS_ONLY_TCP_CONF = ` +[common] +bind_addr = 0.0.0.0 +bind_port = 20000 +log_file = console +log_level = debug +token = 123456 +tls_only = true +` + +const FRPC_TLS_ONLY_TCP_CONF = ` +[common] +server_addr = 127.0.0.1 +server_port = 20000 +log_file = console +log_level = debug +token = 123456 +protocol = tcp +tls_enable = true + +[tcp] +type = tcp +local_port = 10701 +remote_port = 20801 +` + +const FRPC_TLS_ONLY_NO_TLS_TCP_CONF = ` +[common] +server_addr = 127.0.0.1 +server_port = 20000 +log_file = console +log_level = debug +token = 123456 +protocol = tcp +tls_enable = false + +[tcp] +type = tcp +local_port = 10701 +remote_port = 20802 +` + +func TestTlsOnlyOverTCP(t *testing.T) { + assert := assert.New(t) + frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_ONLY_TCP_CONF) + if assert.NoError(err) { + defer os.Remove(frpsCfgPath) + } + + frpcWithTlsCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_ONLY_TCP_CONF) + if assert.NoError(err) { + defer os.Remove(frpcWithTlsCfgPath) + } + + frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath}) + err = frpsProcess.Start() + if assert.NoError(err) { + defer frpsProcess.Stop() + } + + time.Sleep(200 * time.Millisecond) + + frpcProcessWithTls := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcWithTlsCfgPath}) + err = frpcProcessWithTls.Start() + if assert.NoError(err) { + defer frpcProcessWithTls.Stop() + } + time.Sleep(500 * time.Millisecond) + + // test tcp over tls + res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR) + assert.NoError(err) + assert.Equal(consts.TEST_TCP_ECHO_STR, res) + frpcProcessWithTls.Stop() + + frpcWithoutTlsCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_ONLY_NO_TLS_TCP_CONF) + if assert.NoError(err) { + defer os.Remove(frpcWithTlsCfgPath) + } + + frpcProcessWithoutTls := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcWithoutTlsCfgPath}) + err = frpcProcessWithoutTls.Start() + if assert.NoError(err) { + defer frpcProcessWithoutTls.Stop() + } + time.Sleep(500 * time.Millisecond) + + // test tcp without tls + _, err = util.SendTcpMsg("127.0.0.1:20802", consts.TEST_TCP_ECHO_STR) + assert.Error(err) +} diff --git a/utils/net/tls.go b/utils/net/tls.go index b9fca317..d3271226 100644 --- a/utils/net/tls.go +++ b/utils/net/tls.go @@ -16,6 +16,7 @@ package net import ( "crypto/tls" + "fmt" "net" "time" @@ -32,7 +33,7 @@ func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out net.Conn) { return } -func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, timeout time.Duration) (out net.Conn, err error) { +func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration) (out net.Conn, err error) { sc, r := gnet.NewSharedConnSize(c, 2) buf := make([]byte, 1) var n int @@ -46,6 +47,10 @@ func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, t if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE { out = tls.Server(c, tlsConfig) } else { + if tlsOnly { + err = fmt.Errorf("non-TLS connection received on a TlsOnly server") + return + } out = sc } return