mirror of https://github.com/fatedier/frp
in the stars
parent
f01a3c6c75
commit
1f59480ee9
|
@ -1,17 +1,3 @@
|
|||
// Copyright 2025 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vnet
|
||||
|
||||
import (
|
||||
|
@ -32,18 +18,18 @@ import (
|
|||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
const (
|
||||
maxPacketSize = 1420
|
||||
)
|
||||
const maxPacketSize = 1420
|
||||
|
||||
// Controller manages the virtual network TUN interface and routes packets
|
||||
type Controller struct {
|
||||
addr string
|
||||
|
||||
tun io.ReadWriteCloser
|
||||
clientRouter *clientRouter // Route based on destination IP (client mode)
|
||||
serverRouter *serverRouter // Route based on source IP (server mode)
|
||||
clientRouter *clientRouter // routes packets by destination IP (client mode)
|
||||
serverRouter *serverRouter // routes packets by source IP (server mode)
|
||||
}
|
||||
|
||||
// NewController creates a new Controller based on the provided configuration.
|
||||
func NewController(cfg v1.VirtualNetConfig) *Controller {
|
||||
return &Controller{
|
||||
addr: cfg.Address,
|
||||
|
@ -52,6 +38,7 @@ func NewController(cfg v1.VirtualNetConfig) *Controller {
|
|||
}
|
||||
}
|
||||
|
||||
// Init opens the TUN device with the configured address.
|
||||
func (c *Controller) Init() error {
|
||||
tunDevice, err := OpenTun(context.Background(), c.addr)
|
||||
if err != nil {
|
||||
|
@ -61,9 +48,9 @@ func (c *Controller) Init() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Run continuously reads packets from the TUN device and processes them.
|
||||
func (c *Controller) Run() error {
|
||||
conn := c.tun
|
||||
|
||||
for {
|
||||
buf := pool.GetBuf(maxPacketSize)
|
||||
n, err := conn.Read(buf)
|
||||
|
@ -78,11 +65,12 @@ func (c *Controller) Run() error {
|
|||
}
|
||||
}
|
||||
|
||||
// handlePacket processes a single packet. The caller is responsible for managing the buffer.
|
||||
// handlePacket parses the packet header and forwards it to the appropriate route.
|
||||
func (c *Controller) handlePacket(buf []byte) {
|
||||
log.Tracef("vnet read from tun [%d]: %s", len(buf), base64.StdEncoding.EncodeToString(buf))
|
||||
|
||||
var src, dst net.IP
|
||||
|
||||
switch {
|
||||
case waterutil.IsIPv4(buf):
|
||||
header, err := ipv4.ParseHeader(buf)
|
||||
|
@ -90,37 +78,33 @@ func (c *Controller) handlePacket(buf []byte) {
|
|||
log.Warnf("parse ipv4 header error: %v", err)
|
||||
return
|
||||
}
|
||||
src = header.Src
|
||||
dst = header.Dst
|
||||
log.Tracef("%s >> %s %d/%-4d %-4x %d",
|
||||
header.Src, header.Dst,
|
||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
src, dst = header.Src, header.Dst
|
||||
log.Tracef("%s >> %s %d/%-4d %-4x %d", src, dst, header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
|
||||
case waterutil.IsIPv6(buf):
|
||||
header, err := ipv6.ParseHeader(buf)
|
||||
if err != nil {
|
||||
log.Warnf("parse ipv6 header error: %v", err)
|
||||
return
|
||||
}
|
||||
src = header.Src
|
||||
dst = header.Dst
|
||||
log.Tracef("%s >> %s %d %d",
|
||||
header.Src, header.Dst,
|
||||
header.PayloadLen, header.TrafficClass)
|
||||
src, dst = header.Src, header.Dst
|
||||
log.Tracef("%s >> %s %d %d", src, dst, header.PayloadLen, header.TrafficClass)
|
||||
|
||||
default:
|
||||
log.Tracef("unknown packet, discarded(%d)", len(buf))
|
||||
log.Tracef("unknown packet, discarded (%d bytes)", len(buf))
|
||||
return
|
||||
}
|
||||
|
||||
targetConn, err := c.clientRouter.findConn(dst)
|
||||
if err == nil {
|
||||
// Try client route (based on destination IP)
|
||||
if targetConn, err := c.clientRouter.findConn(dst); err == nil {
|
||||
if err := WriteMessage(targetConn, buf); err != nil {
|
||||
log.Warnf("write to client target conn error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
targetConn, err = c.serverRouter.findConnBySrc(dst)
|
||||
if err == nil {
|
||||
// Try server route (based on source IP)
|
||||
if targetConn, err := c.serverRouter.findConnBySrc(dst); err == nil {
|
||||
if err := WriteMessage(targetConn, buf); err != nil {
|
||||
log.Warnf("write to server target conn error: %v", err)
|
||||
}
|
||||
|
@ -130,15 +114,15 @@ func (c *Controller) handlePacket(buf []byte) {
|
|||
log.Tracef("no route found for packet from %s to %s", src, dst)
|
||||
}
|
||||
|
||||
// Stop closes the TUN interface.
|
||||
func (c *Controller) Stop() error {
|
||||
return c.tun.Close()
|
||||
}
|
||||
|
||||
// Client connection read loop
|
||||
// readLoopClient reads packets from a client connection and writes them to the TUN device.
|
||||
func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
defer func() {
|
||||
// Remove the route when read loop ends (connection closed)
|
||||
c.clientRouter.removeConnRoute(conn)
|
||||
conn.Close()
|
||||
}()
|
||||
|
@ -149,50 +133,25 @@ func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser
|
|||
xl.Warnf("client read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
case waterutil.IsIPv4(data):
|
||||
header, err := ipv4.ParseHeader(data)
|
||||
if err != nil {
|
||||
xl.Warnf("parse ipv4 header error: %v", err)
|
||||
continue
|
||||
}
|
||||
xl.Tracef("%s >> %s %d/%-4d %-4x %d",
|
||||
header.Src, header.Dst,
|
||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
case waterutil.IsIPv6(data):
|
||||
header, err := ipv6.ParseHeader(data)
|
||||
if err != nil {
|
||||
xl.Warnf("parse ipv6 header error: %v", err)
|
||||
continue
|
||||
}
|
||||
xl.Tracef("%s >> %s %d %d",
|
||||
header.Src, header.Dst,
|
||||
header.PayloadLen, header.TrafficClass)
|
||||
default:
|
||||
xl.Tracef("unknown packet, discarded(%d)", len(data))
|
||||
continue
|
||||
}
|
||||
logPacketHeader(xl, data)
|
||||
|
||||
xl.Tracef("vnet write to tun (client) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data))
|
||||
_, err = c.tun.Write(data)
|
||||
if err != nil {
|
||||
if _, err := c.tun.Write(data); err != nil {
|
||||
xl.Warnf("client write tun error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Server connection read loop
|
||||
// readLoopServer reads packets from a server connection and writes them to the TUN device,
|
||||
// while maintaining source IP to connection mappings.
|
||||
func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser, onClose func()) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
defer func() {
|
||||
// Clean up all IP mappings associated with this connection when it closes
|
||||
c.serverRouter.cleanupConnIPs(conn)
|
||||
// Call the provided callback upon closure
|
||||
if onClose != nil {
|
||||
onClose()
|
||||
}
|
||||
|
@ -205,56 +164,36 @@ func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser
|
|||
xl.Warnf("server read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Register source IP to connection mapping
|
||||
if waterutil.IsIPv4(data) || waterutil.IsIPv6(data) {
|
||||
var src net.IP
|
||||
if waterutil.IsIPv4(data) {
|
||||
header, err := ipv4.ParseHeader(data)
|
||||
if err == nil {
|
||||
src = header.Src
|
||||
c.serverRouter.registerSrcIP(src, conn)
|
||||
}
|
||||
} else {
|
||||
header, err := ipv6.ParseHeader(data)
|
||||
if err == nil {
|
||||
src = header.Src
|
||||
c.serverRouter.registerSrcIP(src, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
registerSourceIP(c.serverRouter, data)
|
||||
|
||||
xl.Tracef("vnet write to tun (server) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data))
|
||||
_, err = c.tun.Write(data)
|
||||
if err != nil {
|
||||
if _, err := c.tun.Write(data); err != nil {
|
||||
xl.Warnf("server write tun error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClientRoute registers a client route (based on destination IP CIDR)
|
||||
// and starts the read loop
|
||||
// RegisterClientRoute adds a client route and starts the read loop for the connection.
|
||||
func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) {
|
||||
c.clientRouter.addRoute(name, routes, conn)
|
||||
go c.readLoopClient(ctx, conn)
|
||||
}
|
||||
|
||||
// UnregisterClientRoute Remove client route from routing table
|
||||
// UnregisterClientRoute removes a client route.
|
||||
func (c *Controller) UnregisterClientRoute(name string) {
|
||||
c.clientRouter.delRoute(name)
|
||||
}
|
||||
|
||||
// StartServerConnReadLoop starts the read loop for a server connection
|
||||
// (dynamically associates with source IPs)
|
||||
// StartServerConnReadLoop starts the read loop for a server connection with cleanup on close.
|
||||
func (c *Controller) StartServerConnReadLoop(ctx context.Context, conn io.ReadWriteCloser, onClose func()) {
|
||||
go c.readLoopServer(ctx, conn, onClose)
|
||||
}
|
||||
|
||||
// ParseRoutes Convert route strings to IPNet objects
|
||||
// ParseRoutes parses CIDR route strings into net.IPNet objects.
|
||||
func ParseRoutes(routeStrings []string) ([]net.IPNet, error) {
|
||||
routes := make([]net.IPNet, 0, len(routeStrings))
|
||||
for _, r := range routeStrings {
|
||||
|
@ -267,7 +206,38 @@ func ParseRoutes(routeStrings []string) ([]net.IPNet, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
// Client router (based on destination IP routing)
|
||||
// Helper to log IP packet header information
|
||||
func logPacketHeader(xl xlog.Logger, data []byte) {
|
||||
switch {
|
||||
case waterutil.IsIPv4(data):
|
||||
if header, err := ipv4.ParseHeader(data); err == nil {
|
||||
xl.Tracef("%s >> %s %d/%-4d %-4x %d", header.Src, header.Dst, header.Len, header.TotalLen, header.ID, header.Flags)
|
||||
}
|
||||
case waterutil.IsIPv6(data):
|
||||
if header, err := ipv6.ParseHeader(data); err == nil {
|
||||
xl.Tracef("%s >> %s %d %d", header.Src, header.Dst, header.PayloadLen, header.TrafficClass)
|
||||
}
|
||||
default:
|
||||
xl.Tracef("unknown packet, discarded(%d)", len(data))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to register source IP for server router
|
||||
func registerSourceIP(r *serverRouter, data []byte) {
|
||||
if waterutil.IsIPv4(data) {
|
||||
if header, err := ipv4.ParseHeader(data); err == nil {
|
||||
r.registerSrcIP(header.Src, nil) // nil or actual connection if available
|
||||
}
|
||||
} else if waterutil.IsIPv6(data) {
|
||||
if header, err := ipv6.ParseHeader(data); err == nil {
|
||||
r.registerSrcIP(header.Src, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------- clientRouter ------------
|
||||
|
||||
// clientRouter routes packets based on destination IP.
|
||||
type clientRouter struct {
|
||||
routes map[string]*routeElement
|
||||
mu sync.RWMutex
|
||||
|
@ -282,16 +252,13 @@ func newClientRouter() *clientRouter {
|
|||
func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.routes[name] = &routeElement{
|
||||
name: name,
|
||||
routes: routes,
|
||||
conn: conn,
|
||||
}
|
||||
r.routes[name] = &routeElement{name: name, routes: routes, conn: conn}
|
||||
}
|
||||
|
||||
func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, re := range r.routes {
|
||||
for _, route := range re.routes {
|
||||
if route.Contains(dst) {
|
||||
|
@ -299,6 +266,7 @@ func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no route found for destination %s", dst)
|
||||
}
|
||||
|
||||
|
@ -311,6 +279,7 @@ func (r *clientRouter) delRoute(name string) {
|
|||
func (r *clientRouter) removeConnRoute(conn io.Writer) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for name, re := range r.routes {
|
||||
if re.conn == conn {
|
||||
delete(r.routes, name)
|
||||
|
@ -319,9 +288,11 @@ func (r *clientRouter) removeConnRoute(conn io.Writer) {
|
|||
}
|
||||
}
|
||||
|
||||
// Server router (based solely on source IP routing)
|
||||
// ----------- serverRouter ------------
|
||||
|
||||
// serverRouter routes packets based on source IP.
|
||||
type serverRouter struct {
|
||||
srcIPConns map[string]io.Writer // Source IP string to connection mapping
|
||||
srcIPConns map[string]io.Writer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
|
@ -334,6 +305,7 @@ func newServerRouter() *serverRouter {
|
|||
func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
conn, exists := r.srcIPConns[src.String()]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no route found for source %s", src)
|
||||
|
@ -348,16 +320,14 @@ func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) {
|
|||
existingConn, ok := r.srcIPConns[key]
|
||||
r.mu.RUnlock()
|
||||
|
||||
// If the entry exists and the connection is the same, no need to do anything.
|
||||
if ok && existingConn == conn {
|
||||
return
|
||||
}
|
||||
|
||||
// Acquire write lock to update the map.
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring the write lock to handle potential race conditions.
|
||||
// Double-check after locking to avoid race condition
|
||||
existingConn, ok = r.srcIPConns[key]
|
||||
if ok && existingConn == conn {
|
||||
return
|
||||
|
@ -366,12 +336,10 @@ func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) {
|
|||
r.srcIPConns[key] = conn
|
||||
}
|
||||
|
||||
// cleanupConnIPs removes all IP mappings associated with the specified connection
|
||||
func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Find and delete all IP mappings pointing to this connection
|
||||
for ip, mappedConn := range r.srcIPConns {
|
||||
if mappedConn == conn {
|
||||
delete(r.srcIPConns, ip)
|
||||
|
@ -379,6 +347,9 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
|
|||
}
|
||||
}
|
||||
|
||||
// ----------- routeElement ------------
|
||||
|
||||
// routeElement associates a route name with IP networks and a connection.
|
||||
type routeElement struct {
|
||||
name string
|
||||
routes []net.IPNet
|
||||
|
|
Loading…
Reference in New Issue