in the stars

pull/4867/head
Ben.Laskin 2025-06-24 14:22:53 -04:00
parent f01a3c6c75
commit 1f59480ee9
1 changed files with 77 additions and 106 deletions

View File

@ -1,17 +1,3 @@
// Copyright 2025 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vnet
import (
@ -32,18 +18,18 @@ import (
"github.com/fatedier/frp/pkg/util/xlog"
)
const (
maxPacketSize = 1420
)
const maxPacketSize = 1420
// Controller manages the virtual network TUN interface and routes packets
type Controller struct {
addr string
tun io.ReadWriteCloser
clientRouter *clientRouter // Route based on destination IP (client mode)
serverRouter *serverRouter // Route based on source IP (server mode)
clientRouter *clientRouter // routes packets by destination IP (client mode)
serverRouter *serverRouter // routes packets by source IP (server mode)
}
// NewController creates a new Controller based on the provided configuration.
func NewController(cfg v1.VirtualNetConfig) *Controller {
return &Controller{
addr: cfg.Address,
@ -52,6 +38,7 @@ func NewController(cfg v1.VirtualNetConfig) *Controller {
}
}
// Init opens the TUN device with the configured address.
func (c *Controller) Init() error {
tunDevice, err := OpenTun(context.Background(), c.addr)
if err != nil {
@ -61,9 +48,9 @@ func (c *Controller) Init() error {
return nil
}
// Run continuously reads packets from the TUN device and processes them.
func (c *Controller) Run() error {
conn := c.tun
for {
buf := pool.GetBuf(maxPacketSize)
n, err := conn.Read(buf)
@ -78,11 +65,12 @@ func (c *Controller) Run() error {
}
}
// handlePacket processes a single packet. The caller is responsible for managing the buffer.
// handlePacket parses the packet header and forwards it to the appropriate route.
func (c *Controller) handlePacket(buf []byte) {
log.Tracef("vnet read from tun [%d]: %s", len(buf), base64.StdEncoding.EncodeToString(buf))
var src, dst net.IP
switch {
case waterutil.IsIPv4(buf):
header, err := ipv4.ParseHeader(buf)
@ -90,37 +78,33 @@ func (c *Controller) handlePacket(buf []byte) {
log.Warnf("parse ipv4 header error: %v", err)
return
}
src = header.Src
dst = header.Dst
log.Tracef("%s >> %s %d/%-4d %-4x %d",
header.Src, header.Dst,
header.Len, header.TotalLen, header.ID, header.Flags)
src, dst = header.Src, header.Dst
log.Tracef("%s >> %s %d/%-4d %-4x %d", src, dst, header.Len, header.TotalLen, header.ID, header.Flags)
case waterutil.IsIPv6(buf):
header, err := ipv6.ParseHeader(buf)
if err != nil {
log.Warnf("parse ipv6 header error: %v", err)
return
}
src = header.Src
dst = header.Dst
log.Tracef("%s >> %s %d %d",
header.Src, header.Dst,
header.PayloadLen, header.TrafficClass)
src, dst = header.Src, header.Dst
log.Tracef("%s >> %s %d %d", src, dst, header.PayloadLen, header.TrafficClass)
default:
log.Tracef("unknown packet, discarded(%d)", len(buf))
log.Tracef("unknown packet, discarded (%d bytes)", len(buf))
return
}
targetConn, err := c.clientRouter.findConn(dst)
if err == nil {
// Try client route (based on destination IP)
if targetConn, err := c.clientRouter.findConn(dst); err == nil {
if err := WriteMessage(targetConn, buf); err != nil {
log.Warnf("write to client target conn error: %v", err)
}
return
}
targetConn, err = c.serverRouter.findConnBySrc(dst)
if err == nil {
// Try server route (based on source IP)
if targetConn, err := c.serverRouter.findConnBySrc(dst); err == nil {
if err := WriteMessage(targetConn, buf); err != nil {
log.Warnf("write to server target conn error: %v", err)
}
@ -130,15 +114,15 @@ func (c *Controller) handlePacket(buf []byte) {
log.Tracef("no route found for packet from %s to %s", src, dst)
}
// Stop closes the TUN interface.
func (c *Controller) Stop() error {
return c.tun.Close()
}
// Client connection read loop
// readLoopClient reads packets from a client connection and writes them to the TUN device.
func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser) {
xl := xlog.FromContextSafe(ctx)
defer func() {
// Remove the route when read loop ends (connection closed)
c.clientRouter.removeConnRoute(conn)
conn.Close()
}()
@ -149,50 +133,25 @@ func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser
xl.Warnf("client read error: %v", err)
return
}
if len(data) == 0 {
continue
}
switch {
case waterutil.IsIPv4(data):
header, err := ipv4.ParseHeader(data)
if err != nil {
xl.Warnf("parse ipv4 header error: %v", err)
continue
}
xl.Tracef("%s >> %s %d/%-4d %-4x %d",
header.Src, header.Dst,
header.Len, header.TotalLen, header.ID, header.Flags)
case waterutil.IsIPv6(data):
header, err := ipv6.ParseHeader(data)
if err != nil {
xl.Warnf("parse ipv6 header error: %v", err)
continue
}
xl.Tracef("%s >> %s %d %d",
header.Src, header.Dst,
header.PayloadLen, header.TrafficClass)
default:
xl.Tracef("unknown packet, discarded(%d)", len(data))
continue
}
logPacketHeader(xl, data)
xl.Tracef("vnet write to tun (client) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data))
_, err = c.tun.Write(data)
if err != nil {
if _, err := c.tun.Write(data); err != nil {
xl.Warnf("client write tun error: %v", err)
}
}
}
// Server connection read loop
// readLoopServer reads packets from a server connection and writes them to the TUN device,
// while maintaining source IP to connection mappings.
func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser, onClose func()) {
xl := xlog.FromContextSafe(ctx)
defer func() {
// Clean up all IP mappings associated with this connection when it closes
c.serverRouter.cleanupConnIPs(conn)
// Call the provided callback upon closure
if onClose != nil {
onClose()
}
@ -205,56 +164,36 @@ func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser
xl.Warnf("server read error: %v", err)
return
}
if len(data) == 0 {
continue
}
// Register source IP to connection mapping
if waterutil.IsIPv4(data) || waterutil.IsIPv6(data) {
var src net.IP
if waterutil.IsIPv4(data) {
header, err := ipv4.ParseHeader(data)
if err == nil {
src = header.Src
c.serverRouter.registerSrcIP(src, conn)
}
} else {
header, err := ipv6.ParseHeader(data)
if err == nil {
src = header.Src
c.serverRouter.registerSrcIP(src, conn)
}
}
}
registerSourceIP(c.serverRouter, data)
xl.Tracef("vnet write to tun (server) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data))
_, err = c.tun.Write(data)
if err != nil {
if _, err := c.tun.Write(data); err != nil {
xl.Warnf("server write tun error: %v", err)
}
}
}
// RegisterClientRoute registers a client route (based on destination IP CIDR)
// and starts the read loop
// RegisterClientRoute adds a client route and starts the read loop for the connection.
func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) {
c.clientRouter.addRoute(name, routes, conn)
go c.readLoopClient(ctx, conn)
}
// UnregisterClientRoute Remove client route from routing table
// UnregisterClientRoute removes a client route.
func (c *Controller) UnregisterClientRoute(name string) {
c.clientRouter.delRoute(name)
}
// StartServerConnReadLoop starts the read loop for a server connection
// (dynamically associates with source IPs)
// StartServerConnReadLoop starts the read loop for a server connection with cleanup on close.
func (c *Controller) StartServerConnReadLoop(ctx context.Context, conn io.ReadWriteCloser, onClose func()) {
go c.readLoopServer(ctx, conn, onClose)
}
// ParseRoutes Convert route strings to IPNet objects
// ParseRoutes parses CIDR route strings into net.IPNet objects.
func ParseRoutes(routeStrings []string) ([]net.IPNet, error) {
routes := make([]net.IPNet, 0, len(routeStrings))
for _, r := range routeStrings {
@ -267,7 +206,38 @@ func ParseRoutes(routeStrings []string) ([]net.IPNet, error) {
return routes, nil
}
// Client router (based on destination IP routing)
// Helper to log IP packet header information
func logPacketHeader(xl xlog.Logger, data []byte) {
switch {
case waterutil.IsIPv4(data):
if header, err := ipv4.ParseHeader(data); err == nil {
xl.Tracef("%s >> %s %d/%-4d %-4x %d", header.Src, header.Dst, header.Len, header.TotalLen, header.ID, header.Flags)
}
case waterutil.IsIPv6(data):
if header, err := ipv6.ParseHeader(data); err == nil {
xl.Tracef("%s >> %s %d %d", header.Src, header.Dst, header.PayloadLen, header.TrafficClass)
}
default:
xl.Tracef("unknown packet, discarded(%d)", len(data))
}
}
// Helper to register source IP for server router
func registerSourceIP(r *serverRouter, data []byte) {
if waterutil.IsIPv4(data) {
if header, err := ipv4.ParseHeader(data); err == nil {
r.registerSrcIP(header.Src, nil) // nil or actual connection if available
}
} else if waterutil.IsIPv6(data) {
if header, err := ipv6.ParseHeader(data); err == nil {
r.registerSrcIP(header.Src, nil)
}
}
}
// ----------- clientRouter ------------
// clientRouter routes packets based on destination IP.
type clientRouter struct {
routes map[string]*routeElement
mu sync.RWMutex
@ -282,16 +252,13 @@ func newClientRouter() *clientRouter {
func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) {
r.mu.Lock()
defer r.mu.Unlock()
r.routes[name] = &routeElement{
name: name,
routes: routes,
conn: conn,
}
r.routes[name] = &routeElement{name: name, routes: routes, conn: conn}
}
func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, re := range r.routes {
for _, route := range re.routes {
if route.Contains(dst) {
@ -299,6 +266,7 @@ func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) {
}
}
}
return nil, fmt.Errorf("no route found for destination %s", dst)
}
@ -311,6 +279,7 @@ func (r *clientRouter) delRoute(name string) {
func (r *clientRouter) removeConnRoute(conn io.Writer) {
r.mu.Lock()
defer r.mu.Unlock()
for name, re := range r.routes {
if re.conn == conn {
delete(r.routes, name)
@ -319,9 +288,11 @@ func (r *clientRouter) removeConnRoute(conn io.Writer) {
}
}
// Server router (based solely on source IP routing)
// ----------- serverRouter ------------
// serverRouter routes packets based on source IP.
type serverRouter struct {
srcIPConns map[string]io.Writer // Source IP string to connection mapping
srcIPConns map[string]io.Writer
mu sync.RWMutex
}
@ -334,6 +305,7 @@ func newServerRouter() *serverRouter {
func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) {
r.mu.RLock()
defer r.mu.RUnlock()
conn, exists := r.srcIPConns[src.String()]
if !exists {
return nil, fmt.Errorf("no route found for source %s", src)
@ -348,16 +320,14 @@ func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) {
existingConn, ok := r.srcIPConns[key]
r.mu.RUnlock()
// If the entry exists and the connection is the same, no need to do anything.
if ok && existingConn == conn {
return
}
// Acquire write lock to update the map.
r.mu.Lock()
defer r.mu.Unlock()
// Double-check after acquiring the write lock to handle potential race conditions.
// Double-check after locking to avoid race condition
existingConn, ok = r.srcIPConns[key]
if ok && existingConn == conn {
return
@ -366,12 +336,10 @@ func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) {
r.srcIPConns[key] = conn
}
// cleanupConnIPs removes all IP mappings associated with the specified connection
func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
r.mu.Lock()
defer r.mu.Unlock()
// Find and delete all IP mappings pointing to this connection
for ip, mappedConn := range r.srcIPConns {
if mappedConn == conn {
delete(r.srcIPConns, ip)
@ -379,6 +347,9 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
}
}
// ----------- routeElement ------------
// routeElement associates a route name with IP networks and a connection.
type routeElement struct {
name string
routes []net.IPNet