From 9ece38fca97876dd435a4194b769c93f191d3dce Mon Sep 17 00:00:00 2001
From: Fionera <fionera@fionera.de>
Date: Mon, 21 Mar 2022 11:38:05 +0100
Subject: [PATCH] refactor: Use netlink for tcpstat collector

Signed-off-by: Tim Windelschmidt <t.windelschmidt@babiel.com>
---
 collector/fixtures/proc/net/tcpstat |   3 -
 collector/tcpstat_linux.go          | 137 ++++++++++++++++++----------
 collector/tcpstat_linux_test.go     | 125 +++++++++----------------
 go.mod                              |   1 +
 4 files changed, 133 insertions(+), 133 deletions(-)
 delete mode 100644 collector/fixtures/proc/net/tcpstat

diff --git a/collector/fixtures/proc/net/tcpstat b/collector/fixtures/proc/net/tcpstat
deleted file mode 100644
index 352c00bb..00000000
--- a/collector/fixtures/proc/net/tcpstat
+++ /dev/null
@@ -1,3 +0,0 @@
-  sl  local_address rem_address   st tx_queue rx_queue tr tm->when retrnsmt   uid  timeout inode                                                     
-   0: 00000000:0016 00000000:0000 0A 00000015:00000000 00:00000000 00000000     0        0 2740 1 ffff88003d3af3c0 100 0 0 10 0                      
-   1: 0F02000A:0016 0202000A:8B6B 01 00000015:00000001 02:000AC99B 00000000     0        0 3652 4 ffff88003d3ae040 21 4 31 47 46                     
diff --git a/collector/tcpstat_linux.go b/collector/tcpstat_linux.go
index 47c3f3e0..99e33bc6 100644
--- a/collector/tcpstat_linux.go
+++ b/collector/tcpstat_linux.go
@@ -18,13 +18,12 @@ package collector
 
 import (
 	"fmt"
-	"io"
-	"io/ioutil"
 	"os"
-	"strconv"
-	"strings"
+	"syscall"
+	"unsafe"
 
 	"github.com/go-kit/log"
+	"github.com/mdlayher/netlink"
 	"github.com/prometheus/client_golang/prometheus"
 )
 
@@ -80,16 +79,64 @@ func NewTCPStatCollector(logger log.Logger) (Collector, error) {
 	}, nil
 }
 
+// InetDiagSockID mirrors the kernel's struct inet_diag_sockid, which identifies one socket.
+// https://github.com/torvalds/linux/blob/v4.0/include/uapi/linux/inet_diag.h#L13
+type InetDiagSockID struct {
+	SourcePort [2]byte    // idiag_sport
+	DestPort   [2]byte    // idiag_dport
+	SourceIP   [4][4]byte // idiag_src
+	DestIP     [4][4]byte // idiag_dst
+	Interface  uint32     // idiag_if
+	Cookie     [2]uint32  // idiag_cookie
+}
+
+// InetDiagReqV2 mirrors the kernel's struct inet_diag_req_v2, the payload of a socket dump request.
+// https://github.com/torvalds/linux/blob/v4.0/include/uapi/linux/inet_diag.h#L37
+type InetDiagReqV2 struct {
+	Family   uint8          // sdiag_family, e.g. syscall.AF_INET or syscall.AF_INET6
+	Protocol uint8          // sdiag_protocol, e.g. syscall.IPPROTO_TCP
+	Ext      uint8          // idiag_ext, bitmask in which bit n-1 requests extension n
+	Pad      uint8          // pad
+	States   uint32         // idiag_states, bitmask of the socket states to report
+	ID       InetDiagSockID // id, left zeroed by getTCPStats
+}
+
+const sizeOfDiagRequest = 0x38 // 56 bytes: 8 bytes of scalar fields plus the 48-byte InetDiagSockID
+
+func (req *InetDiagReqV2) Serialize() []byte {
+	return (*(*[sizeOfDiagRequest]byte)(unsafe.Pointer(req)))[:] // byte view of req's own memory, not a copy
+}
+
+func (req *InetDiagReqV2) Len() int {
+	return sizeOfDiagRequest // length of the slice Serialize returns
+}
+
+type InetDiagMsg struct { // InetDiagMsg (inet_diag_msg) is the per-socket payload of a dump response
+	Family  uint8          // idiag_family
+	State   uint8          // idiag_state, converted to tcpConnectionState by parseTCPStats
+	Timer   uint8          // idiag_timer
+	Retrans uint8          // idiag_retrans
+	ID      InetDiagSockID // id
+	Expires uint32         // idiag_expires
+	RQueue  uint32         // idiag_rqueue, summed into tcpRxQueuedBytes
+	WQueue  uint32         // idiag_wqueue, summed into tcpTxQueuedBytes
+	UID     uint32         // idiag_uid
+	Inode   uint32         // idiag_inode
+}
+
+func parseInetDiagMsg(b []byte) *InetDiagMsg {
+	return (*InetDiagMsg)(unsafe.Pointer(&b[0])) // aliases b without copying; no length check here, and an empty b panics
+}
+
 func (c *tcpStatCollector) Update(ch chan<- prometheus.Metric) error {
-	tcpStats, err := getTCPStats(procFilePath("net/tcp"))
+	tcpStats, err := getTCPStats(syscall.AF_INET)
 	if err != nil {
 		return fmt.Errorf("couldn't get tcpstats: %w", err)
 	}
 
 	// if enabled ipv6 system
-	tcp6File := procFilePath("net/tcp6")
-	if _, hasIPv6 := os.Stat(tcp6File); hasIPv6 == nil {
-		tcp6Stats, err := getTCPStats(tcp6File)
+	if _, hasIPv6 := os.Stat(procFilePath("net/tcp6")); hasIPv6 == nil {
+		tcp6Stats, err := getTCPStats(syscall.AF_INET6)
 		if err != nil {
 			return fmt.Errorf("couldn't get tcp6stats: %w", err)
 		}
@@ -102,59 +149,69 @@ func (c *tcpStatCollector) Update(ch chan<- prometheus.Metric) error {
 	for st, value := range tcpStats {
 		ch <- c.desc.mustNewConstMetric(value, st.String())
 	}
+
 	return nil
 }
 
-func getTCPStats(statsFile string) (map[tcpConnectionState]float64, error) {
-	file, err := os.Open(statsFile)
+// getTCPStats dumps all TCP sockets of the given address family (for example
+// syscall.AF_INET or syscall.AF_INET6) through a NETLINK_INET_DIAG socket and
+// returns the number of sockets per connection state together with the summed
+// receive and transmit queue sizes.
+func getTCPStats(family uint8) (map[tcpConnectionState]float64, error) {
+	// TCPFAll is the state bitmask that selects sockets in every TCP state.
+	const TCPFAll = 0xFFF
+	// InetDiagInfo is the number of the INET_DIAG_INFO extension.
+	const InetDiagInfo = 2
+	// SockDiagByFamily is the SOCK_DIAG_BY_FAMILY netlink message type.
+	const SockDiagByFamily = 20
+
+	conn, err := netlink.Dial(syscall.NETLINK_INET_DIAG, nil)
+	if err != nil {
+		return nil, fmt.Errorf("couldn't connect netlink: %w", err)
+	}
+	defer conn.Close()
+
+	msg := netlink.Message{
+		Header: netlink.Header{
+			Type:  SockDiagByFamily,
+			Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_DUMP,
+		},
+		Data: (&InetDiagReqV2{
+			Family:   family,
+			Protocol: syscall.IPPROTO_TCP,
+			States:   TCPFAll,
+			Ext:      0 | 1<<(InetDiagInfo-1),
+		}).Serialize(),
+	}
+
+	messages, err := conn.Execute(msg)
 	if err != nil {
 		return nil, err
 	}
-	defer file.Close()
 
-	return parseTCPStats(file)
+	return parseTCPStats(messages)
 }
 
-func parseTCPStats(r io.Reader) (map[tcpConnectionState]float64, error) {
+// parseTCPStats aggregates inet_diag responses into one map: every message
+// increments the counter of its socket state and adds its queue sizes to the
+// tcpRxQueuedBytes and tcpTxQueuedBytes entries. It returns an error when a
+// message is too short to hold an InetDiagMsg.
+func parseTCPStats(msgs []netlink.Message) (map[tcpConnectionState]float64, error) {
+	// parseInetDiagMsg reinterprets the payload in place, so the payload must
+	// cover the whole struct before any field may be read.
+	const minMsgLen = int(unsafe.Sizeof(InetDiagMsg{}))
+
 	tcpStats := map[tcpConnectionState]float64{}
-	contents, err := ioutil.ReadAll(r)
-	if err != nil {
-		return nil, err
-	}
 
-	for _, line := range strings.Split(string(contents), "\n")[1:] {
-		parts := strings.Fields(line)
-		if len(parts) == 0 {
-			continue
-		}
-		if len(parts) < 5 {
-			return nil, fmt.Errorf("invalid TCP stats line: %q", line)
-		}
-
-		qu := strings.Split(parts[4], ":")
-		if len(qu) < 2 {
-			return nil, fmt.Errorf("cannot parse tx_queues and rx_queues: %q", line)
-		}
-
-		tx, err := strconv.ParseUint(qu[0], 16, 64)
-		if err != nil {
-			return nil, err
-		}
-		tcpStats[tcpConnectionState(tcpTxQueuedBytes)] += float64(tx)
-
-		rx, err := strconv.ParseUint(qu[1], 16, 64)
-		if err != nil {
-			return nil, err
-		}
-		tcpStats[tcpConnectionState(tcpRxQueuedBytes)] += float64(rx)
-
-		st, err := strconv.ParseInt(parts[3], 16, 8)
-		if err != nil {
-			return nil, err
-		}
-
-		tcpStats[tcpConnectionState(st)]++
+	for _, m := range msgs {
+		if len(m.Data) < minMsgLen {
+			return nil, fmt.Errorf("invalid inet_diag_msg: got %d bytes, want at least %d", len(m.Data), minMsgLen)
+		}
+		msg := parseInetDiagMsg(m.Data)
 
+		tcpStats[tcpTxQueuedBytes] += float64(msg.WQueue)
+		tcpStats[tcpRxQueuedBytes] += float64(msg.RQueue)
+		tcpStats[tcpConnectionState(msg.State)]++
 	}
 
 	return tcpStats, nil
diff --git a/collector/tcpstat_linux_test.go b/collector/tcpstat_linux_test.go
index b609b846..37dc1eee 100644
--- a/collector/tcpstat_linux_test.go
+++ b/collector/tcpstat_linux_test.go
@@ -14,66 +14,56 @@
 package collector
 
 import (
-	"os"
-	"strings"
+	"bytes"
+	"encoding/binary"
+	"syscall"
 	"testing"
+
+	"github.com/mdlayher/netlink"
 )
 
-func Test_parseTCPStatsError(t *testing.T) {
-	tests := []struct {
-		name string
-		in   string
-	}{
-		{
-			name: "too few fields",
-			in:   "sl  local_address\n  0: 00000000:0016",
-		},
-		{
-			name: "missing colon in tx-rx field",
-			in: "sl  local_address rem_address   st tx_queue rx_queue\n" +
-				" 1: 0F02000A:0016 0202000A:8B6B 01 0000000000000001",
-		},
-		{
-			name: "tx parsing issue",
-			in: "sl  local_address rem_address   st tx_queue rx_queue\n" +
-				" 1: 0F02000A:0016 0202000A:8B6B 01 0000000x:00000001",
-		},
-		{
-			name: "rx parsing issue",
-			in: "sl  local_address rem_address   st tx_queue rx_queue\n" +
-				" 1: 0F02000A:0016 0202000A:8B6B 01 00000000:0000000x",
-		},
-		{
-			name: "state parsing issue",
-			in: "sl  local_address rem_address   st tx_queue rx_queue\n" +
-				" 1: 0F02000A:0016 0202000A:8B6B 0H 00000000:00000001",
-		},
-	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			if _, err := parseTCPStats(strings.NewReader(tt.in)); err == nil {
-				t.Fatal("expected an error, but none occurred")
-			}
-		})
-	}
-}
-
-func TestTCPStat(t *testing.T) {
-
-	noFile, _ := os.Open("follow the white rabbit")
-	defer noFile.Close()
-
-	if _, err := parseTCPStats(noFile); err == nil {
-		t.Fatal("expected an error, but none occurred")
+func Test_parseTCPStats(t *testing.T) {
+	encode := func(m InetDiagMsg) []byte {
+		var buf bytes.Buffer
+		err := binary.Write(&buf, binary.LittleEndian, m)
+		if err != nil {
+			panic(err)
+		}
+		return buf.Bytes()
 	}
 
-	file, err := os.Open("fixtures/proc/net/tcpstat")
-	if err != nil {
-		t.Fatal(err)
+	msg := []netlink.Message{
+		{
+			Data: encode(InetDiagMsg{
+				Family:  syscall.AF_INET,
+				State:   uint8(tcpEstablished),
+				Timer:   0,
+				Retrans: 0,
+				ID:      InetDiagSockID{},
+				Expires: 0,
+				RQueue:  11,
+				WQueue:  21,
+				UID:     0,
+				Inode:   0,
+			}),
+		},
+		{
+			Data: encode(InetDiagMsg{
+				Family:  syscall.AF_INET,
+				State:   uint8(tcpListen),
+				Timer:   0,
+				Retrans: 0,
+				ID:      InetDiagSockID{},
+				Expires: 0,
+				RQueue:  11,
+				WQueue:  21,
+				UID:     0,
+				Inode:   0,
+			}),
+		},
 	}
-	defer file.Close()
 
-	tcpStats, err := parseTCPStats(file)
+	tcpStats, err := parseTCPStats(msg)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -89,35 +79,8 @@ func TestTCPStat(t *testing.T) {
 	if want, got := 42, int(tcpStats[tcpTxQueuedBytes]); want != got {
 		t.Errorf("want tcpstat number of bytes in tx queue %d, got %d", want, got)
 	}
-	if want, got := 1, int(tcpStats[tcpRxQueuedBytes]); want != got {
+	if want, got := 22, int(tcpStats[tcpRxQueuedBytes]); want != got {
 		t.Errorf("want tcpstat number of bytes in rx queue %d, got %d", want, got)
 	}
 
 }
-
-func Test_getTCPStats(t *testing.T) {
-	type args struct {
-		statsFile string
-	}
-	tests := []struct {
-		name    string
-		args    args
-		wantErr bool
-	}{
-		{
-			name:    "file not found",
-			args:    args{statsFile: "somewhere over the rainbow"},
-			wantErr: true,
-		},
-	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			_, err := getTCPStats(tt.args.statsFile)
-			if (err != nil) != tt.wantErr {
-				t.Errorf("getTCPStats() error = %v, wantErr %v", err, tt.wantErr)
-				return
-			}
-			// other cases are covered by TestTCPStat()
-		})
-	}
-}
diff --git a/go.mod b/go.mod
index a78dc8ff..d1be25f0 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
 	github.com/jsimonetti/rtnetlink v1.1.1
 	github.com/lufia/iostat v1.2.1
 	github.com/mattn/go-xmlrpc v0.0.3
+	github.com/mdlayher/netlink v1.6.0
 	github.com/mdlayher/wifi v0.0.0-20220320220353-954ff73a19a5
 	github.com/prometheus/client_golang v1.12.1
 	github.com/prometheus/client_model v0.2.0