From 1f59480ee93aaf258f4c0e40aedf5baa7d59cbc4 Mon Sep 17 00:00:00 2001 From: "Ben.Laskin" Date: Tue, 24 Jun 2025 14:22:53 -0400 Subject: [PATCH] in the stars --- pkg/vnet/controller.go | 183 +++++++++++++++++------------------------ 1 file changed, 77 insertions(+), 106 deletions(-) diff --git a/pkg/vnet/controller.go b/pkg/vnet/controller.go index ca71a8c3..75276407 100644 --- a/pkg/vnet/controller.go +++ b/pkg/vnet/controller.go @@ -1,17 +1,3 @@ -// Copyright 2025 The frp Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package vnet import ( @@ -32,18 +18,18 @@ import ( "github.com/fatedier/frp/pkg/util/xlog" ) -const ( - maxPacketSize = 1420 -) +const maxPacketSize = 1420 +// Controller manages the virtual network TUN interface and routes packets type Controller struct { addr string tun io.ReadWriteCloser - clientRouter *clientRouter // Route based on destination IP (client mode) - serverRouter *serverRouter // Route based on source IP (server mode) + clientRouter *clientRouter // routes packets by destination IP (client mode) + serverRouter *serverRouter // routes packets by source IP (server mode) } +// NewController creates a new Controller based on the provided configuration. func NewController(cfg v1.VirtualNetConfig) *Controller { return &Controller{ addr: cfg.Address, @@ -52,6 +38,7 @@ func NewController(cfg v1.VirtualNetConfig) *Controller { } } +// Init opens the TUN device with the configured address. func (c *Controller) Init() error { tunDevice, err := OpenTun(context.Background(), c.addr) if err != nil { @@ -61,9 +48,9 @@ func (c *Controller) Init() error { return nil } +// Run continuously reads packets from the TUN device and processes them. func (c *Controller) Run() error { conn := c.tun - for { buf := pool.GetBuf(maxPacketSize) n, err := conn.Read(buf) @@ -78,11 +65,12 @@ func (c *Controller) Run() error { } } -// handlePacket processes a single packet. The caller is responsible for managing the buffer. +// handlePacket parses the packet header and forwards it to the appropriate route. func (c *Controller) handlePacket(buf []byte) { log.Tracef("vnet read from tun [%d]: %s", len(buf), base64.StdEncoding.EncodeToString(buf)) var src, dst net.IP + switch { case waterutil.IsIPv4(buf): header, err := ipv4.ParseHeader(buf) @@ -90,37 +78,33 @@ func (c *Controller) handlePacket(buf []byte) { log.Warnf("parse ipv4 header error: %v", err) return } - src = header.Src - dst = header.Dst - log.Tracef("%s >> %s %d/%-4d %-4x %d", - header.Src, header.Dst, - header.Len, header.TotalLen, header.ID, header.Flags) + src, dst = header.Src, header.Dst + log.Tracef("%s >> %s %d/%-4d %-4x %d", src, dst, header.Len, header.TotalLen, header.ID, header.Flags) + case waterutil.IsIPv6(buf): header, err := ipv6.ParseHeader(buf) if err != nil { log.Warnf("parse ipv6 header error: %v", err) return } - src = header.Src - dst = header.Dst - log.Tracef("%s >> %s %d %d", - header.Src, header.Dst, - header.PayloadLen, header.TrafficClass) + src, dst = header.Src, header.Dst + log.Tracef("%s >> %s %d %d", src, dst, header.PayloadLen, header.TrafficClass) + default: - log.Tracef("unknown packet, discarded(%d)", len(buf)) + log.Tracef("unknown packet, discarded (%d bytes)", len(buf)) return } - targetConn, err := c.clientRouter.findConn(dst) - if err == nil { + // Try client route (based on destination IP) + if targetConn, err := c.clientRouter.findConn(dst); err == nil { if err := WriteMessage(targetConn, buf); err != nil { log.Warnf("write to client target conn error: %v", err) } return } - targetConn, err = c.serverRouter.findConnBySrc(dst) - if err == nil { + // Try server route (based on source IP) + if targetConn, err := c.serverRouter.findConnBySrc(dst); err == nil { if err := WriteMessage(targetConn, buf); err != nil { log.Warnf("write to server target conn error: %v", err) } @@ -130,15 +114,15 @@ func (c *Controller) handlePacket(buf []byte) { log.Tracef("no route found for packet from %s to %s", src, dst) } +// Stop closes the TUN interface. func (c *Controller) Stop() error { return c.tun.Close() } -// Client connection read loop +// readLoopClient reads packets from a client connection and writes them to the TUN device. func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser) { xl := xlog.FromContextSafe(ctx) defer func() { - // Remove the route when read loop ends (connection closed) c.clientRouter.removeConnRoute(conn) conn.Close() }() @@ -149,50 +133,25 @@ func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser xl.Warnf("client read error: %v", err) return } - if len(data) == 0 { continue } - switch { - case waterutil.IsIPv4(data): - header, err := ipv4.ParseHeader(data) - if err != nil { - xl.Warnf("parse ipv4 header error: %v", err) - continue - } - xl.Tracef("%s >> %s %d/%-4d %-4x %d", - header.Src, header.Dst, - header.Len, header.TotalLen, header.ID, header.Flags) - case waterutil.IsIPv6(data): - header, err := ipv6.ParseHeader(data) - if err != nil { - xl.Warnf("parse ipv6 header error: %v", err) - continue - } - xl.Tracef("%s >> %s %d %d", - header.Src, header.Dst, - header.PayloadLen, header.TrafficClass) - default: - xl.Tracef("unknown packet, discarded(%d)", len(data)) - continue - } + logPacketHeader(xl, data) xl.Tracef("vnet write to tun (client) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data)) - _, err = c.tun.Write(data) - if err != nil { + if _, err := c.tun.Write(data); err != nil { xl.Warnf("client write tun error: %v", err) } } } -// Server connection read loop +// readLoopServer reads packets from a server connection and writes them to the TUN device, +// while maintaining source IP to connection mappings. func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser, onClose func()) { xl := xlog.FromContextSafe(ctx) defer func() { - // Clean up all IP mappings associated with this connection when it closes c.serverRouter.cleanupConnIPs(conn) - // Call the provided callback upon closure if onClose != nil { onClose() } @@ -205,56 +164,36 @@ func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser xl.Warnf("server read error: %v", err) return } - if len(data) == 0 { continue } - // Register source IP to connection mapping - if waterutil.IsIPv4(data) || waterutil.IsIPv6(data) { - var src net.IP - if waterutil.IsIPv4(data) { - header, err := ipv4.ParseHeader(data) - if err == nil { - src = header.Src - c.serverRouter.registerSrcIP(src, conn) - } - } else { - header, err := ipv6.ParseHeader(data) - if err == nil { - src = header.Src - c.serverRouter.registerSrcIP(src, conn) - } - } - } + registerSourceIP(c.serverRouter, data) xl.Tracef("vnet write to tun (server) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data)) - _, err = c.tun.Write(data) - if err != nil { + if _, err := c.tun.Write(data); err != nil { xl.Warnf("server write tun error: %v", err) } } } -// RegisterClientRoute registers a client route (based on destination IP CIDR) -// and starts the read loop +// RegisterClientRoute adds a client route and starts the read loop for the connection. func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) { c.clientRouter.addRoute(name, routes, conn) go c.readLoopClient(ctx, conn) } -// UnregisterClientRoute Remove client route from routing table +// UnregisterClientRoute removes a client route. func (c *Controller) UnregisterClientRoute(name string) { c.clientRouter.delRoute(name) } -// StartServerConnReadLoop starts the read loop for a server connection -// (dynamically associates with source IPs) +// StartServerConnReadLoop starts the read loop for a server connection with cleanup on close. func (c *Controller) StartServerConnReadLoop(ctx context.Context, conn io.ReadWriteCloser, onClose func()) { go c.readLoopServer(ctx, conn, onClose) } -// ParseRoutes Convert route strings to IPNet objects +// ParseRoutes parses CIDR route strings into net.IPNet objects. func ParseRoutes(routeStrings []string) ([]net.IPNet, error) { routes := make([]net.IPNet, 0, len(routeStrings)) for _, r := range routeStrings { @@ -267,7 +206,38 @@ func ParseRoutes(routeStrings []string) ([]net.IPNet, error) { return routes, nil } -// Client router (based on destination IP routing) +// Helper to log IP packet header information +func logPacketHeader(xl xlog.Logger, data []byte) { + switch { + case waterutil.IsIPv4(data): + if header, err := ipv4.ParseHeader(data); err == nil { + xl.Tracef("%s >> %s %d/%-4d %-4x %d", header.Src, header.Dst, header.Len, header.TotalLen, header.ID, header.Flags) + } + case waterutil.IsIPv6(data): + if header, err := ipv6.ParseHeader(data); err == nil { + xl.Tracef("%s >> %s %d %d", header.Src, header.Dst, header.PayloadLen, header.TrafficClass) + } + default: + xl.Tracef("unknown packet, discarded(%d)", len(data)) + } +} + +// Helper to register source IP for server router +func registerSourceIP(r *serverRouter, data []byte) { + if waterutil.IsIPv4(data) { + if header, err := ipv4.ParseHeader(data); err == nil { + r.registerSrcIP(header.Src, nil) // nil or actual connection if available + } + } else if waterutil.IsIPv6(data) { + if header, err := ipv6.ParseHeader(data); err == nil { + r.registerSrcIP(header.Src, nil) + } + } +} + +// ----------- clientRouter ------------ + +// clientRouter routes packets based on destination IP. type clientRouter struct { routes map[string]*routeElement mu sync.RWMutex @@ -282,16 +252,13 @@ func newClientRouter() *clientRouter { func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) { r.mu.Lock() defer r.mu.Unlock() - r.routes[name] = &routeElement{ - name: name, - routes: routes, - conn: conn, - } + r.routes[name] = &routeElement{name: name, routes: routes, conn: conn} } func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) { r.mu.RLock() defer r.mu.RUnlock() + for _, re := range r.routes { for _, route := range re.routes { if route.Contains(dst) { @@ -299,6 +266,7 @@ func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) { } } } + return nil, fmt.Errorf("no route found for destination %s", dst) } @@ -311,6 +279,7 @@ func (r *clientRouter) delRoute(name string) { func (r *clientRouter) removeConnRoute(conn io.Writer) { r.mu.Lock() defer r.mu.Unlock() + for name, re := range r.routes { if re.conn == conn { delete(r.routes, name) @@ -319,9 +288,11 @@ func (r *clientRouter) removeConnRoute(conn io.Writer) { } } -// Server router (based solely on source IP routing) +// ----------- serverRouter ------------ + +// serverRouter routes packets based on source IP. type serverRouter struct { - srcIPConns map[string]io.Writer // Source IP string to connection mapping + srcIPConns map[string]io.Writer mu sync.RWMutex } @@ -334,6 +305,7 @@ func newServerRouter() *serverRouter { func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) { r.mu.RLock() defer r.mu.RUnlock() + conn, exists := r.srcIPConns[src.String()] if !exists { return nil, fmt.Errorf("no route found for source %s", src) @@ -348,16 +320,14 @@ func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) { existingConn, ok := r.srcIPConns[key] r.mu.RUnlock() - // If the entry exists and the connection is the same, no need to do anything. if ok && existingConn == conn { return } - // Acquire write lock to update the map. r.mu.Lock() defer r.mu.Unlock() - // Double-check after acquiring the write lock to handle potential race conditions. + // Double-check after locking to avoid race condition existingConn, ok = r.srcIPConns[key] if ok && existingConn == conn { return @@ -366,12 +336,10 @@ func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) { r.srcIPConns[key] = conn } -// cleanupConnIPs removes all IP mappings associated with the specified connection func (r *serverRouter) cleanupConnIPs(conn io.Writer) { r.mu.Lock() defer r.mu.Unlock() - // Find and delete all IP mappings pointing to this connection for ip, mappedConn := range r.srcIPConns { if mappedConn == conn { delete(r.srcIPConns, ip) @@ -379,6 +347,9 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) { } } +// ----------- routeElement ------------ + +// routeElement associates a route name with IP networks and a connection. type routeElement struct { name string routes []net.IPNet