initial commit

This commit is contained in:
Karol Broda
2025-12-16 22:42:49 +01:00
commit 371f4d13a6
61 changed files with 6872 additions and 0 deletions

View File

@@ -0,0 +1,506 @@
package collector
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
// Collector interface defines methods for collecting connection data
type Collector interface {
GetConnections() ([]Connection, error)
}
// DefaultCollector implements the Collector interface using /proc
type DefaultCollector struct{}
// Global collector instance (can be overridden for testing)
var globalCollector Collector = &DefaultCollector{}
// SetCollector sets the global collector instance
func SetCollector(collector Collector) {
globalCollector = collector
}
// GetCollector returns the current global collector instance
func GetCollector() Collector {
return globalCollector
}
// GetConnections fetches all network connections using the global collector
func GetConnections() ([]Connection, error) {
return globalCollector.GetConnections()
}
// GetConnections fetches all network connections by parsing /proc files.
func (dc *DefaultCollector) GetConnections() ([]Connection, error) {
if runtime.GOOS != "linux" {
return nil, fmt.Errorf("proc-based collector only supports Linux")
}
// Build map of inode -> process info by scanning /proc
inodeMap, err := buildInodeToProcessMap()
if err != nil {
return nil, fmt.Errorf("failed to build inode map: %w", err)
}
var connections []Connection
// Parse TCP connections
tcpConns, err := parseProcNet("/proc/net/tcp", "tcp", 4, inodeMap)
if err == nil {
connections = append(connections, tcpConns...)
}
tcpConns6, err := parseProcNet("/proc/net/tcp6", "tcp6", 6, inodeMap)
if err == nil {
connections = append(connections, tcpConns6...)
}
// Parse UDP connections
udpConns, err := parseProcNet("/proc/net/udp", "udp", 4, inodeMap)
if err == nil {
connections = append(connections, udpConns...)
}
udpConns6, err := parseProcNet("/proc/net/udp6", "udp6", 6, inodeMap)
if err == nil {
connections = append(connections, udpConns6...)
}
return connections, nil
}
// GetAllConnections returns both network and Unix domain socket connections
func GetAllConnections() ([]Connection, error) {
// Get network connections
networkConns, err := GetConnections()
if err != nil {
return nil, err
}
// Get Unix sockets (only on Linux)
if runtime.GOOS == "linux" {
unixConns, err := GetUnixSockets()
if err == nil {
networkConns = append(networkConns, unixConns...)
}
}
return networkConns, nil
}
func FilterConnections(conns []Connection, filters FilterOptions) []Connection {
if filters.IsEmpty() {
return conns
}
filtered := make([]Connection, 0)
for _, conn := range conns {
if filters.Matches(conn) {
filtered = append(filtered, conn)
}
}
return filtered
}
// processInfo holds information about a process
type processInfo struct {
pid int
command string
uid int
user string
}
// buildInodeToProcessMap scans /proc to map socket inodes to processes
func buildInodeToProcessMap() (map[int64]*processInfo, error) {
inodeMap := make(map[int64]*processInfo)
procDir, err := os.Open("/proc")
if err != nil {
return nil, err
}
defer procDir.Close()
entries, err := procDir.Readdir(-1)
if err != nil {
return nil, err
}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
// check if directory name is a number (pid)
pidStr := entry.Name()
pid, err := strconv.Atoi(pidStr)
if err != nil {
continue
}
// get process info
procInfo, err := getProcessInfo(pid)
if err != nil {
continue
}
// scan /proc/[pid]/fd/ for socket file descriptors
fdDir := filepath.Join("/proc", pidStr, "fd")
fdEntries, err := os.ReadDir(fdDir)
if err != nil {
continue
}
for _, fdEntry := range fdEntries {
fdPath := filepath.Join(fdDir, fdEntry.Name())
link, err := os.Readlink(fdPath)
if err != nil {
continue
}
// socket inodes look like: socket:[12345]
if strings.HasPrefix(link, "socket:[") && strings.HasSuffix(link, "]") {
inodeStr := link[8 : len(link)-1]
inode, err := strconv.ParseInt(inodeStr, 10, 64)
if err != nil {
continue
}
inodeMap[inode] = procInfo
}
}
}
return inodeMap, nil
}
// getProcessInfo reads process information from /proc/[pid]/
func getProcessInfo(pid int) (*processInfo, error) {
info := &processInfo{pid: pid}
// prefer /proc/[pid]/comm as it's always just the command name
commPath := filepath.Join("/proc", strconv.Itoa(pid), "comm")
commData, err := os.ReadFile(commPath)
if err == nil && len(commData) > 0 {
info.command = strings.TrimSpace(string(commData))
}
// if comm is not available, try cmdline
if info.command == "" {
cmdlinePath := filepath.Join("/proc", strconv.Itoa(pid), "cmdline")
cmdlineData, err := os.ReadFile(cmdlinePath)
if err != nil {
return nil, err
}
// cmdline is null-separated, take first part
if len(cmdlineData) > 0 {
parts := bytes.Split(cmdlineData, []byte{0})
if len(parts) > 0 && len(parts[0]) > 0 {
fullPath := string(parts[0])
// extract basename from full path
baseName := filepath.Base(fullPath)
// if basename contains spaces (single-string cmdline), take first word
if strings.Contains(baseName, " ") {
baseName = strings.Fields(baseName)[0]
}
info.command = baseName
}
}
}
// read UID from /proc/[pid]/status
statusPath := filepath.Join("/proc", strconv.Itoa(pid), "status")
statusFile, err := os.Open(statusPath)
if err != nil {
return info, nil
}
defer statusFile.Close()
scanner := bufio.NewScanner(statusFile)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "Uid:") {
fields := strings.Fields(line)
if len(fields) >= 2 {
uid, err := strconv.Atoi(fields[1])
if err == nil {
info.uid = uid
// get username from uid
u, err := user.LookupId(strconv.Itoa(uid))
if err == nil {
info.user = u.Username
} else {
info.user = strconv.Itoa(uid)
}
}
}
break
}
}
return info, nil
}
// parseProcNet parses a /proc/net/tcp or /proc/net/udp file
func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*processInfo) ([]Connection, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
var connections []Connection
scanner := bufio.NewScanner(file)
// skip header
scanner.Scan()
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
fields := strings.Fields(line)
if len(fields) < 10 {
continue
}
// parse local address and port
localAddr, localPort, err := parseHexAddr(fields[1])
if err != nil {
continue
}
// parse remote address and port
remoteAddr, remotePort, err := parseHexAddr(fields[2])
if err != nil {
continue
}
// parse state (field 3)
stateHex := fields[3]
state := parseState(stateHex, proto)
// parse inode (field 9)
inode, _ := strconv.ParseInt(fields[9], 10, 64)
conn := Connection{
TS: time.Now(),
Proto: proto,
IPVersion: fmt.Sprintf("IPv%d", ipVersion),
State: state,
Laddr: localAddr,
Lport: localPort,
Raddr: remoteAddr,
Rport: remotePort,
Inode: inode,
}
// add process info if available
if procInfo, exists := inodeMap[inode]; exists {
conn.PID = procInfo.pid
conn.Process = procInfo.command
conn.UID = procInfo.uid
conn.User = procInfo.user
}
// determine interface
conn.Interface = guessNetworkInterface(localAddr, nil)
connections = append(connections, conn)
}
return connections, scanner.Err()
}
// parseState converts hex state value to string
func parseState(hexState, proto string) string {
state, err := strconv.ParseInt(hexState, 16, 32)
if err != nil {
return ""
}
// TCP states
tcpStates := map[int64]string{
0x01: "ESTABLISHED",
0x02: "SYN_SENT",
0x03: "SYN_RECV",
0x04: "FIN_WAIT1",
0x05: "FIN_WAIT2",
0x06: "TIME_WAIT",
0x07: "CLOSE",
0x08: "CLOSE_WAIT",
0x09: "LAST_ACK",
0x0A: "LISTEN",
0x0B: "CLOSING",
}
if strings.HasPrefix(proto, "tcp") {
if s, exists := tcpStates[state]; exists {
return s
}
} else {
// UDP doesn't have states in the same way
if state == 0x07 {
return "CLOSE"
}
return ""
}
return ""
}
// parseHexAddr parses hex-encoded address:port from /proc/net files
func parseHexAddr(hexAddr string) (string, int, error) {
parts := strings.Split(hexAddr, ":")
if len(parts) != 2 {
return "", 0, fmt.Errorf("invalid address format")
}
hexIP := parts[0]
// parse hex port
port, err := strconv.ParseInt(parts[1], 16, 32)
if err != nil {
return "", 0, err
}
if len(hexIP) == 8 {
// IPv4 (stored in little-endian)
ip1, _ := strconv.ParseInt(hexIP[6:8], 16, 32)
ip2, _ := strconv.ParseInt(hexIP[4:6], 16, 32)
ip3, _ := strconv.ParseInt(hexIP[2:4], 16, 32)
ip4, _ := strconv.ParseInt(hexIP[0:2], 16, 32)
addr := fmt.Sprintf("%d.%d.%d.%d", ip1, ip2, ip3, ip4)
// handle wildcard address
if addr == "0.0.0.0" {
addr = "*"
}
return addr, int(port), nil
} else if len(hexIP) == 32 {
// IPv6 (stored in little-endian per 32-bit word)
var ipv6Parts []string
for i := 0; i < 32; i += 8 {
word := hexIP[i : i+8]
// reverse byte order within each 32-bit word
p1 := word[6:8] + word[4:6] + word[2:4] + word[0:2]
ipv6Parts = append(ipv6Parts, p1)
}
// convert to standard IPv6 notation
fullAddr := strings.Join(ipv6Parts, "")
var formatted []string
for i := 0; i < len(fullAddr); i += 4 {
formatted = append(formatted, fullAddr[i:i+4])
}
addr := strings.Join(formatted, ":")
// simplify IPv6 address
addr = simplifyIPv6(addr)
// handle wildcard address
if addr == "::" || addr == "0:0:0:0:0:0:0:0" {
addr = "*"
}
return addr, int(port), nil
}
return "", 0, fmt.Errorf("unsupported address format")
}
// simplifyIPv6 simplifies IPv6 address notation
func simplifyIPv6(addr string) string {
// remove leading zeros from each group
parts := strings.Split(addr, ":")
for i, part := range parts {
// convert to int and back to remove leading zeros
val, err := strconv.ParseInt(part, 16, 64)
if err == nil {
parts[i] = strconv.FormatInt(val, 16)
}
}
return strings.Join(parts, ":")
}
func guessNetworkInterface(addr string, interfaces map[string]string) string {
// Simple heuristic - try to match common interface patterns
if addr == "127.0.0.1" || addr == "::1" {
return "lo"
}
// Check if it's a private network address
ip := net.ParseIP(addr)
if ip != nil {
if ip.IsLoopback() {
return "lo"
}
// More sophisticated interface detection would require routing table analysis
// For now, return a placeholder
if ip.To4() != nil {
return "eth0" // Common default for IPv4
} else {
return "eth0" // Common default for IPv6
}
}
return ""
}
// Add Unix socket support
func GetUnixSockets() ([]Connection, error) {
connections := []Connection{}
// Parse /proc/net/unix for Unix domain sockets
file, err := os.Open("/proc/net/unix")
if err != nil {
return connections, nil // silently fail on non-Linux systems
}
defer file.Close()
scanner := bufio.NewScanner(file)
// Skip header
scanner.Scan()
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
fields := strings.Fields(line)
if len(fields) < 7 {
continue
}
// Parse Unix socket information
inode, _ := strconv.ParseInt(fields[6], 10, 64)
path := ""
if len(fields) > 7 {
path = fields[7]
}
conn := Connection{
TS: time.Now(),
Proto: "unix",
Laddr: path,
Raddr: "",
State: "CONNECTED", // Simplified
Inode: inode,
Interface: "unix",
}
connections = append(connections, conn)
}
return connections, nil
}

View File

@@ -0,0 +1,16 @@
package collector
import (
"testing"
)
func TestGetConnections(t *testing.T) {
// integration test to verify /proc parsing works
conns, err := GetConnections()
if err != nil {
t.Fatalf("GetConnections() returned an error: %v", err)
}
// connections are dynamic, so just verify function succeeded
t.Logf("Successfully got %d connections", len(conns))
}

View File

@@ -0,0 +1,129 @@
package collector
import (
"strings"
"time"
)
type FilterOptions struct {
Proto string
State string
Pid int
Proc string
Lport int
Rport int
User string
UID int
Laddr string
Raddr string
Contains string
IPv4 bool
IPv6 bool
Interface string
Mark string
Namespace string
Inode int64
Since time.Time
SinceRel time.Duration
}
func (f *FilterOptions) IsEmpty() bool {
return f.Proto == "" && f.State == "" && f.Pid == 0 && f.Proc == "" &&
f.Lport == 0 && f.Rport == 0 && f.User == "" && f.UID == 0 &&
f.Laddr == "" && f.Raddr == "" && f.Contains == "" &&
f.Interface == "" && f.Mark == "" && f.Namespace == "" && f.Inode == 0 &&
f.Since.IsZero() && f.SinceRel == 0 && !f.IPv4 && !f.IPv6
}
func (f *FilterOptions) Matches(c Connection) bool {
if f.Proto != "" && !strings.EqualFold(c.Proto, f.Proto) {
return false
}
if f.State != "" && !strings.EqualFold(c.State, f.State) {
return false
}
if f.Pid != 0 && c.PID != f.Pid {
return false
}
if f.Proc != "" && !containsIgnoreCase(c.Process, f.Proc) {
return false
}
if f.Lport != 0 && c.Lport != f.Lport {
return false
}
if f.Rport != 0 && c.Rport != f.Rport {
return false
}
if f.User != "" && !strings.EqualFold(c.User, f.User) {
return false
}
if f.UID != 0 && c.UID != f.UID {
return false
}
if f.Laddr != "" && !strings.EqualFold(c.Laddr, f.Laddr) {
return false
}
if f.Raddr != "" && !strings.EqualFold(c.Raddr, f.Raddr) {
return false
}
if f.Contains != "" && !matchesContains(c, f.Contains) {
return false
}
if f.IPv4 && c.IPVersion != "IPv4" {
return false
}
if f.IPv6 && c.IPVersion != "IPv6" {
return false
}
if f.Interface != "" && !strings.EqualFold(c.Interface, f.Interface) {
return false
}
if f.Mark != "" && !strings.EqualFold(c.Mark, f.Mark) {
return false
}
if f.Namespace != "" && !strings.EqualFold(c.Namespace, f.Namespace) {
return false
}
if f.Inode != 0 && c.Inode != f.Inode {
return false
}
if !f.Since.IsZero() && c.TS.Before(f.Since) {
return false
}
if f.SinceRel != 0 {
threshold := time.Now().Add(-f.SinceRel)
if c.TS.Before(threshold) {
return false
}
}
return true
}
func containsIgnoreCase(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
func matchesContains(c Connection, query string) bool {
q := strings.ToLower(query)
return containsIgnoreCase(c.Process, q) ||
containsIgnoreCase(c.Laddr, q) ||
containsIgnoreCase(c.Raddr, q) ||
containsIgnoreCase(c.User, q)
}
// ParseTimeFilter parses a time filter string (RFC3339 or relative like "5s", "2m", "1h")
func ParseTimeFilter(timeStr string) (time.Time, time.Duration, error) {
// Try parsing as RFC3339 first
if t, err := time.Parse(time.RFC3339, timeStr); err == nil {
return t, 0, nil
}
// Try parsing as relative duration
if dur, err := time.ParseDuration(timeStr); err == nil {
return time.Time{}, dur, nil
}
return time.Time{}, 0, nil // Invalid format, but don't error
}

View File

@@ -0,0 +1,44 @@
package collector
import (
"testing"
)
func TestFilterConnections(t *testing.T) {
conns := []Connection{
{PID: 1, Process: "proc1", User: "user1", Proto: "tcp", State: "ESTABLISHED", Laddr: "1.1.1.1", Lport: 80, Raddr: "2.2.2.2", Rport: 1234},
{PID: 2, Process: "proc2", User: "user2", Proto: "udp", State: "LISTEN", Laddr: "3.3.3.3", Lport: 53, Raddr: "*", Rport: 0},
{PID: 3, Process: "proc1_extra", User: "user1", Proto: "tcp", State: "ESTABLISHED", Laddr: "4.4.4.4", Lport: 443, Raddr: "5.5.5.5", Rport: 5678},
}
testCases := []struct {
name string
filters FilterOptions
expected int
}{
{"No filters", FilterOptions{}, 3},
{"Filter by proto tcp", FilterOptions{Proto: "tcp"}, 2},
{"Filter by proto udp", FilterOptions{Proto: "udp"}, 1},
{"Filter by state", FilterOptions{State: "ESTABLISHED"}, 2},
{"Filter by pid", FilterOptions{Pid: 2}, 1},
{"Filter by proc", FilterOptions{Proc: "proc1"}, 2},
{"Filter by lport", FilterOptions{Lport: 80}, 1},
{"Filter by rport", FilterOptions{Rport: 1234}, 1},
{"Filter by user", FilterOptions{User: "user1"}, 2},
{"Filter by laddr", FilterOptions{Laddr: "1.1.1.1"}, 1},
{"Filter by raddr", FilterOptions{Raddr: "5.5.5.5"}, 1},
{"Filter by contains proc", FilterOptions{Contains: "proc2"}, 1},
{"Filter by contains addr", FilterOptions{Contains: "3.3.3.3"}, 1},
{"Combined filter", FilterOptions{Proto: "tcp", State: "ESTABLISHED"}, 2},
{"No match", FilterOptions{Proto: "tcp", State: "LISTEN"}, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
filtered := FilterConnections(conns, tc.filters)
if len(filtered) != tc.expected {
t.Errorf("Expected %d connections, but got %d", tc.expected, len(filtered))
}
})
}
}

403
internal/collector/mock.go Normal file
View File

@@ -0,0 +1,403 @@
package collector
import (
"encoding/json"
"os"
"time"
)
// MockCollector provides deterministic test data for testing
// It implements the Collector interface
type MockCollector struct {
connections []Connection
}
// NewMockCollector creates a new mock collector with default test data
func NewMockCollector() *MockCollector {
return &MockCollector{
connections: getDefaultTestConnections(),
}
}
// NewMockCollectorFromFile creates a mock collector from a JSON fixture file
func NewMockCollectorFromFile(filename string) (*MockCollector, error) {
data, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
var connections []Connection
if err := json.Unmarshal(data, &connections); err != nil {
return nil, err
}
return &MockCollector{
connections: connections,
}, nil
}
// GetConnections returns the mock connections
func (m *MockCollector) GetConnections() ([]Connection, error) {
// Return a copy to avoid mutation
result := make([]Connection, len(m.connections))
copy(result, m.connections)
return result, nil
}
// AddConnection adds a connection to the mock data
func (m *MockCollector) AddConnection(conn Connection) {
m.connections = append(m.connections, conn)
}
// SetConnections replaces all connections with the provided slice
func (m *MockCollector) SetConnections(connections []Connection) {
m.connections = make([]Connection, len(connections))
copy(m.connections, connections)
}
// SaveToFile saves the current connections to a JSON file
func (m *MockCollector) SaveToFile(filename string) error {
data, err := json.MarshalIndent(m.connections, "", " ")
if err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
// getDefaultTestConnections returns a set of deterministic test connections
func getDefaultTestConnections() []Connection {
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
return []Connection{
{
TS: baseTime,
PID: 1234,
Process: "nginx",
User: "www-data",
UID: 33,
Proto: "tcp",
IPVersion: "IPv4",
State: "LISTEN",
Laddr: "0.0.0.0",
Lport: 80,
Raddr: "*",
Rport: 0,
Interface: "eth0",
RxBytes: 0,
TxBytes: 0,
RttMs: 0,
Mark: "0x0",
Namespace: "init",
Inode: 12345,
},
{
TS: baseTime.Add(time.Second),
PID: 1234,
Process: "nginx",
User: "www-data",
UID: 33,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "10.0.0.1",
Lport: 80,
Raddr: "203.0.113.10",
Rport: 52344,
Interface: "eth0",
RxBytes: 10240,
TxBytes: 2048,
RttMs: 1.7,
Mark: "0x0",
Namespace: "init",
Inode: 12346,
},
{
TS: baseTime.Add(2 * time.Second),
PID: 5678,
Process: "postgres",
User: "postgres",
UID: 26,
Proto: "tcp",
IPVersion: "IPv4",
State: "LISTEN",
Laddr: "127.0.0.1",
Lport: 5432,
Raddr: "*",
Rport: 0,
Interface: "lo",
RxBytes: 0,
TxBytes: 0,
RttMs: 0,
Mark: "0x0",
Namespace: "init",
Inode: 12347,
},
{
TS: baseTime.Add(3 * time.Second),
PID: 5678,
Process: "postgres",
User: "postgres",
UID: 26,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "127.0.0.1",
Lport: 5432,
Raddr: "127.0.0.1",
Rport: 45678,
Interface: "lo",
RxBytes: 8192,
TxBytes: 4096,
RttMs: 0.1,
Mark: "0x0",
Namespace: "init",
Inode: 12348,
},
{
TS: baseTime.Add(4 * time.Second),
PID: 9999,
Process: "dns-server",
User: "named",
UID: 25,
Proto: "udp",
IPVersion: "IPv4",
State: "CONNECTED",
Laddr: "0.0.0.0",
Lport: 53,
Raddr: "*",
Rport: 0,
Interface: "eth0",
RxBytes: 1024,
TxBytes: 512,
RttMs: 0,
Mark: "0x0",
Namespace: "init",
Inode: 12349,
},
{
TS: baseTime.Add(5 * time.Second),
PID: 1111,
Process: "ssh",
User: "root",
UID: 0,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "192.168.1.100",
Lport: 22,
Raddr: "192.168.1.200",
Rport: 54321,
Interface: "eth0",
RxBytes: 2048,
TxBytes: 1024,
RttMs: 2.3,
Mark: "0x0",
Namespace: "init",
Inode: 12350,
},
{
TS: baseTime.Add(6 * time.Second),
PID: 2222,
Process: "app-server",
User: "app",
UID: 1000,
Proto: "unix",
IPVersion: "",
State: "CONNECTED",
Laddr: "/tmp/app.sock",
Lport: 0,
Raddr: "",
Rport: 0,
Interface: "unix",
RxBytes: 512,
TxBytes: 256,
RttMs: 0,
Mark: "",
Namespace: "init",
Inode: 12351,
},
}
}
// ConnectionBuilder provides a fluent interface for building test connections
type ConnectionBuilder struct {
conn Connection
}
// NewConnectionBuilder creates a new connection builder with sensible defaults
func NewConnectionBuilder() *ConnectionBuilder {
return &ConnectionBuilder{
conn: Connection{
TS: time.Now(),
PID: 1000,
Process: "test-process",
User: "test-user",
UID: 1000,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "127.0.0.1",
Lport: 8080,
Raddr: "127.0.0.1",
Rport: 9090,
Interface: "lo",
RxBytes: 1024,
TxBytes: 512,
RttMs: 1.0,
Mark: "0x0",
Namespace: "init",
Inode: 99999,
},
}
}
// WithPID sets the PID
func (b *ConnectionBuilder) WithPID(pid int) *ConnectionBuilder {
b.conn.PID = pid
return b
}
// WithProcess sets the process name
func (b *ConnectionBuilder) WithProcess(process string) *ConnectionBuilder {
b.conn.Process = process
return b
}
// WithProto sets the protocol
func (b *ConnectionBuilder) WithProto(proto string) *ConnectionBuilder {
b.conn.Proto = proto
return b
}
// WithState sets the connection state
func (b *ConnectionBuilder) WithState(state string) *ConnectionBuilder {
b.conn.State = state
return b
}
// WithLocalAddr sets the local address and port
func (b *ConnectionBuilder) WithLocalAddr(addr string, port int) *ConnectionBuilder {
b.conn.Laddr = addr
b.conn.Lport = port
return b
}
// WithRemoteAddr sets the remote address and port
func (b *ConnectionBuilder) WithRemoteAddr(addr string, port int) *ConnectionBuilder {
b.conn.Raddr = addr
b.conn.Rport = port
return b
}
// WithInterface sets the network interface
func (b *ConnectionBuilder) WithInterface(iface string) *ConnectionBuilder {
b.conn.Interface = iface
return b
}
// WithBytes sets the rx and tx byte counts
func (b *ConnectionBuilder) WithBytes(rx, tx int64) *ConnectionBuilder {
b.conn.RxBytes = rx
b.conn.TxBytes = tx
return b
}
// Build returns the constructed connection
func (b *ConnectionBuilder) Build() Connection {
return b.conn
}
// TestFixture provides test scenarios for different use cases
type TestFixture struct {
Name string
Description string
Connections []Connection
}
// GetTestFixtures returns predefined test fixtures for different scenarios
func GetTestFixtures() []TestFixture {
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
return []TestFixture{
{
Name: "empty",
Description: "No connections",
Connections: []Connection{},
},
{
Name: "single-tcp",
Description: "Single TCP connection",
Connections: []Connection{
NewConnectionBuilder().
WithPID(1234).
WithProcess("test-app").
WithProto("tcp").
WithState("ESTABLISHED").
WithLocalAddr("127.0.0.1", 8080).
WithRemoteAddr("127.0.0.1", 9090).
Build(),
},
},
{
Name: "mixed-protocols",
Description: "Mix of TCP, UDP, and Unix sockets",
Connections: []Connection{
{
TS: baseTime,
PID: 1,
Process: "tcp-server",
Proto: "tcp",
State: "LISTEN",
Laddr: "0.0.0.0",
Lport: 80,
Interface: "eth0",
},
{
TS: baseTime.Add(time.Second),
PID: 2,
Process: "udp-server",
Proto: "udp",
State: "CONNECTED",
Laddr: "0.0.0.0",
Lport: 53,
Interface: "eth0",
},
{
TS: baseTime.Add(2 * time.Second),
PID: 3,
Process: "unix-app",
Proto: "unix",
State: "CONNECTED",
Laddr: "/tmp/test.sock",
Interface: "unix",
},
},
},
{
Name: "high-volume",
Description: "Large number of connections for performance testing",
Connections: generateHighVolumeConnections(1000),
},
}
}
// generateHighVolumeConnections creates a large number of test connections
func generateHighVolumeConnections(count int) []Connection {
connections := make([]Connection, count)
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
for i := 0; i < count; i++ {
connections[i] = NewConnectionBuilder().
WithPID(1000 + i).
WithProcess("worker-" + string(rune('a'+i%26))).
WithProto([]string{"tcp", "udp"}[i%2]).
WithState([]string{"ESTABLISHED", "LISTEN", "TIME_WAIT"}[i%3]).
WithLocalAddr("127.0.0.1", 8000+i%1000).
WithRemoteAddr("10.0.0."+string(rune('1'+i%10)), 9000+i%1000).
Build()
connections[i].TS = baseTime.Add(time.Duration(i) * time.Millisecond)
}
return connections
}

159
internal/collector/query.go Normal file
View File

@@ -0,0 +1,159 @@
package collector
// Query combines filtering, sorting, and limiting into a single operation
type Query struct {
Filter FilterOptions
Sort SortOptions
Limit int
}
// NewQuery creates a query with sensible defaults
func NewQuery() *Query {
return &Query{
Filter: FilterOptions{},
Sort: SortOptions{Field: SortByLport, Direction: SortAsc},
Limit: 0,
}
}
// WithFilter sets the filter options
func (q *Query) WithFilter(f FilterOptions) *Query {
q.Filter = f
return q
}
// WithSort sets the sort options
func (q *Query) WithSort(s SortOptions) *Query {
q.Sort = s
return q
}
// WithSortString parses and sets sort options from a string like "pid:desc"
func (q *Query) WithSortString(s string) *Query {
q.Sort = ParseSortOptions(s)
return q
}
// WithLimit sets the maximum number of results
func (q *Query) WithLimit(n int) *Query {
q.Limit = n
return q
}
// Proto filters by protocol
func (q *Query) Proto(proto string) *Query {
q.Filter.Proto = proto
return q
}
// State filters by connection state
func (q *Query) State(state string) *Query {
q.Filter.State = state
return q
}
// Process filters by process name (substring match)
func (q *Query) Process(proc string) *Query {
q.Filter.Proc = proc
return q
}
// PID filters by process ID
func (q *Query) PID(pid int) *Query {
q.Filter.Pid = pid
return q
}
// LocalPort filters by local port
func (q *Query) LocalPort(port int) *Query {
q.Filter.Lport = port
return q
}
// RemotePort filters by remote port
func (q *Query) RemotePort(port int) *Query {
q.Filter.Rport = port
return q
}
// IPv4Only filters to only IPv4 connections
func (q *Query) IPv4Only() *Query {
q.Filter.IPv4 = true
q.Filter.IPv6 = false
return q
}
// IPv6Only filters to only IPv6 connections
func (q *Query) IPv6Only() *Query {
q.Filter.IPv4 = false
q.Filter.IPv6 = true
return q
}
// Listening filters to only listening sockets
func (q *Query) Listening() *Query {
q.Filter.State = "LISTEN"
return q
}
// Established filters to only established connections
func (q *Query) Established() *Query {
q.Filter.State = "ESTABLISHED"
return q
}
// Contains filters by substring in process, local addr, or remote addr
func (q *Query) Contains(s string) *Query {
q.Filter.Contains = s
return q
}
// Execute runs the query and returns results
func (q *Query) Execute() ([]Connection, error) {
conns, err := GetConnections()
if err != nil {
return nil, err
}
return q.Apply(conns), nil
}
// Apply applies the query to a slice of connections
func (q *Query) Apply(conns []Connection) []Connection {
result := FilterConnections(conns, q.Filter)
SortConnections(result, q.Sort)
if q.Limit > 0 && len(result) > q.Limit {
result = result[:q.Limit]
}
return result
}
// common pre-built queries
// ListeningTCP returns a query for TCP listeners
func ListeningTCP() *Query {
return NewQuery().Proto("tcp").Listening()
}
// ListeningAll returns a query for all listeners
func ListeningAll() *Query {
return NewQuery().Listening()
}
// EstablishedTCP returns a query for established TCP connections
func EstablishedTCP() *Query {
return NewQuery().Proto("tcp").Established()
}
// ByProcess returns a query filtered by process name
func ByProcess(name string) *Query {
return NewQuery().Process(name)
}
// ByPort returns a query filtered by local port
func ByPort(port int) *Query {
return NewQuery().LocalPort(port)
}

View File

@@ -0,0 +1,165 @@
package collector
import (
"testing"
)
func TestQueryBuilder(t *testing.T) {
t.Run("fluent builder pattern", func(t *testing.T) {
q := NewQuery().
Proto("tcp").
State("LISTEN").
WithLimit(10)
if q.Filter.Proto != "tcp" {
t.Errorf("expected proto tcp, got %s", q.Filter.Proto)
}
if q.Filter.State != "LISTEN" {
t.Errorf("expected state LISTEN, got %s", q.Filter.State)
}
if q.Limit != 10 {
t.Errorf("expected limit 10, got %d", q.Limit)
}
})
t.Run("convenience methods", func(t *testing.T) {
q := NewQuery().Listening()
if q.Filter.State != "LISTEN" {
t.Errorf("Listening() should set state to LISTEN")
}
q = NewQuery().Established()
if q.Filter.State != "ESTABLISHED" {
t.Errorf("Established() should set state to ESTABLISHED")
}
q = NewQuery().IPv4Only()
if q.Filter.IPv4 != true || q.Filter.IPv6 != false {
t.Error("IPv4Only() should set IPv4=true, IPv6=false")
}
q = NewQuery().IPv6Only()
if q.Filter.IPv4 != false || q.Filter.IPv6 != true {
t.Error("IPv6Only() should set IPv4=false, IPv6=true")
}
})
t.Run("sort options", func(t *testing.T) {
q := NewQuery().WithSortString("pid:desc")
if q.Sort.Field != SortByPID {
t.Errorf("expected sort by pid, got %v", q.Sort.Field)
}
if q.Sort.Direction != SortDesc {
t.Errorf("expected sort desc, got %v", q.Sort.Direction)
}
})
}
func TestQueryApply(t *testing.T) {
conns := []Connection{
{PID: 1, Process: "nginx", Proto: "tcp", State: "LISTEN", Lport: 80},
{PID: 2, Process: "nginx", Proto: "tcp", State: "ESTABLISHED", Lport: 80},
{PID: 3, Process: "sshd", Proto: "tcp", State: "LISTEN", Lport: 22},
{PID: 4, Process: "postgres", Proto: "tcp", State: "LISTEN", Lport: 5432},
{PID: 5, Process: "dnsmasq", Proto: "udp", State: "", Lport: 53},
}
t.Run("filter by state", func(t *testing.T) {
q := NewQuery().Listening()
result := q.Apply(conns)
if len(result) != 3 {
t.Errorf("expected 3 listening connections, got %d", len(result))
}
})
t.Run("filter by proto", func(t *testing.T) {
q := NewQuery().Proto("udp")
result := q.Apply(conns)
if len(result) != 1 {
t.Errorf("expected 1 udp connection, got %d", len(result))
}
if result[0].Process != "dnsmasq" {
t.Errorf("expected dnsmasq, got %s", result[0].Process)
}
})
t.Run("filter and sort", func(t *testing.T) {
q := NewQuery().Listening().WithSortString("lport:asc")
result := q.Apply(conns)
if len(result) != 3 {
t.Fatalf("expected 3, got %d", len(result))
}
if result[0].Lport != 22 {
t.Errorf("expected port 22 first, got %d", result[0].Lport)
}
})
t.Run("filter sort and limit", func(t *testing.T) {
q := NewQuery().Proto("tcp").WithSortString("lport:asc").WithLimit(2)
result := q.Apply(conns)
if len(result) != 2 {
t.Errorf("expected 2 (limit), got %d", len(result))
}
})
t.Run("process filter substring", func(t *testing.T) {
q := NewQuery().Process("nginx")
result := q.Apply(conns)
if len(result) != 2 {
t.Errorf("expected 2 nginx connections, got %d", len(result))
}
})
t.Run("contains filter", func(t *testing.T) {
q := NewQuery().Contains("post")
result := q.Apply(conns)
if len(result) != 1 || result[0].Process != "postgres" {
t.Errorf("expected postgres, got %v", result)
}
})
}
func TestPrebuiltQueries(t *testing.T) {
t.Run("ListeningTCP", func(t *testing.T) {
q := ListeningTCP()
if q.Filter.Proto != "tcp" || q.Filter.State != "LISTEN" {
t.Error("ListeningTCP should filter tcp + LISTEN")
}
})
t.Run("ListeningAll", func(t *testing.T) {
q := ListeningAll()
if q.Filter.State != "LISTEN" {
t.Error("ListeningAll should filter LISTEN state")
}
})
t.Run("EstablishedTCP", func(t *testing.T) {
q := EstablishedTCP()
if q.Filter.Proto != "tcp" || q.Filter.State != "ESTABLISHED" {
t.Error("EstablishedTCP should filter tcp + ESTABLISHED")
}
})
t.Run("ByProcess", func(t *testing.T) {
q := ByProcess("nginx")
if q.Filter.Proc != "nginx" {
t.Error("ByProcess should set process filter")
}
})
t.Run("ByPort", func(t *testing.T) {
q := ByPort(8080)
if q.Filter.Lport != 8080 {
t.Error("ByPort should set lport filter")
}
})
}

131
internal/collector/sort.go Normal file
View File

@@ -0,0 +1,131 @@
package collector
import (
"sort"
"strings"
)
// SortField represents a field to sort by
type SortField string
const (
SortByPID SortField = "pid"
SortByProcess SortField = "process"
SortByUser SortField = "user"
SortByProto SortField = "proto"
SortByState SortField = "state"
SortByLaddr SortField = "laddr"
SortByLport SortField = "lport"
SortByRaddr SortField = "raddr"
SortByRport SortField = "rport"
SortByInterface SortField = "if"
SortByRxBytes SortField = "rx_bytes"
SortByTxBytes SortField = "tx_bytes"
SortByRttMs SortField = "rtt_ms"
SortByTimestamp SortField = "ts"
)
// SortDirection represents ascending or descending order
type SortDirection int
const (
SortAsc SortDirection = iota
SortDesc
)
// SortOptions configures how connections are sorted
type SortOptions struct {
Field SortField
Direction SortDirection
}
// ParseSortOptions parses a sort string like "pid:desc" or "lport"
func ParseSortOptions(s string) SortOptions {
if s == "" {
return SortOptions{Field: SortByLport, Direction: SortAsc}
}
parts := strings.SplitN(s, ":", 2)
field := SortField(strings.ToLower(parts[0]))
direction := SortAsc
if len(parts) > 1 && strings.ToLower(parts[1]) == "desc" {
direction = SortDesc
}
return SortOptions{Field: field, Direction: direction}
}
// SortConnections sorts a slice of connections in place
func SortConnections(conns []Connection, opts SortOptions) {
if len(conns) < 2 {
return
}
sort.SliceStable(conns, func(i, j int) bool {
less := compareConnections(conns[i], conns[j], opts.Field)
if opts.Direction == SortDesc {
return !less
}
return less
})
}
func compareConnections(a, b Connection, field SortField) bool {
switch field {
case SortByPID:
return a.PID < b.PID
case SortByProcess:
return strings.ToLower(a.Process) < strings.ToLower(b.Process)
case SortByUser:
return strings.ToLower(a.User) < strings.ToLower(b.User)
case SortByProto:
return a.Proto < b.Proto
case SortByState:
return stateOrder(a.State) < stateOrder(b.State)
case SortByLaddr:
return a.Laddr < b.Laddr
case SortByLport:
return a.Lport < b.Lport
case SortByRaddr:
return a.Raddr < b.Raddr
case SortByRport:
return a.Rport < b.Rport
case SortByInterface:
return a.Interface < b.Interface
case SortByRxBytes:
return a.RxBytes < b.RxBytes
case SortByTxBytes:
return a.TxBytes < b.TxBytes
case SortByRttMs:
return a.RttMs < b.RttMs
case SortByTimestamp:
return a.TS.Before(b.TS)
default:
return a.Lport < b.Lport
}
}
// stateOrder returns a numeric order for connection states
// puts LISTEN first, then ESTABLISHED, then others
func stateOrder(state string) int {
order := map[string]int{
"LISTEN": 0,
"ESTABLISHED": 1,
"SYN_SENT": 2,
"SYN_RECV": 3,
"FIN_WAIT1": 4,
"FIN_WAIT2": 5,
"TIME_WAIT": 6,
"CLOSE_WAIT": 7,
"LAST_ACK": 8,
"CLOSING": 9,
"CLOSED": 10,
}
if o, exists := order[strings.ToUpper(state)]; exists {
return o
}
return 99
}

View File

@@ -0,0 +1,130 @@
package collector
import (
"testing"
"time"
)
func TestSortConnections(t *testing.T) {
conns := []Connection{
{PID: 3, Process: "nginx", Lport: 80, State: "ESTABLISHED"},
{PID: 1, Process: "sshd", Lport: 22, State: "LISTEN"},
{PID: 2, Process: "postgres", Lport: 5432, State: "LISTEN"},
}
t.Run("sort by PID ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByPID, Direction: SortAsc})
if c[0].PID != 1 || c[1].PID != 2 || c[2].PID != 3 {
t.Errorf("expected PIDs [1,2,3], got [%d,%d,%d]", c[0].PID, c[1].PID, c[2].PID)
}
})
t.Run("sort by PID descending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByPID, Direction: SortDesc})
if c[0].PID != 3 || c[1].PID != 2 || c[2].PID != 1 {
t.Errorf("expected PIDs [3,2,1], got [%d,%d,%d]", c[0].PID, c[1].PID, c[2].PID)
}
})
t.Run("sort by port ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByLport, Direction: SortAsc})
if c[0].Lport != 22 || c[1].Lport != 80 || c[2].Lport != 5432 {
t.Errorf("expected ports [22,80,5432], got [%d,%d,%d]", c[0].Lport, c[1].Lport, c[2].Lport)
}
})
t.Run("sort by state puts LISTEN first", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByState, Direction: SortAsc})
if c[0].State != "LISTEN" || c[1].State != "LISTEN" || c[2].State != "ESTABLISHED" {
t.Errorf("expected LISTEN states first, got [%s,%s,%s]", c[0].State, c[1].State, c[2].State)
}
})
t.Run("sort by process case insensitive", func(t *testing.T) {
c := []Connection{
{Process: "Nginx"},
{Process: "apache"},
{Process: "SSHD"},
}
SortConnections(c, SortOptions{Field: SortByProcess, Direction: SortAsc})
if c[0].Process != "apache" {
t.Errorf("expected apache first, got %s", c[0].Process)
}
})
}
func TestParseSortOptions(t *testing.T) {
tests := []struct {
input string
wantField SortField
wantDir SortDirection
}{
{"pid", SortByPID, SortAsc},
{"pid:asc", SortByPID, SortAsc},
{"pid:desc", SortByPID, SortDesc},
{"lport", SortByLport, SortAsc},
{"LPORT:DESC", SortByLport, SortDesc},
{"", SortByLport, SortAsc}, // default
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
opts := ParseSortOptions(tt.input)
if opts.Field != tt.wantField {
t.Errorf("field: got %v, want %v", opts.Field, tt.wantField)
}
if opts.Direction != tt.wantDir {
t.Errorf("direction: got %v, want %v", opts.Direction, tt.wantDir)
}
})
}
}
func TestStateOrder(t *testing.T) {
if stateOrder("LISTEN") >= stateOrder("ESTABLISHED") {
t.Error("LISTEN should come before ESTABLISHED")
}
if stateOrder("ESTABLISHED") >= stateOrder("TIME_WAIT") {
t.Error("ESTABLISHED should come before TIME_WAIT")
}
if stateOrder("UNKNOWN") != 99 {
t.Error("unknown states should return 99")
}
}
func TestSortByTimestamp(t *testing.T) {
now := time.Now()
conns := []Connection{
{TS: now.Add(2 * time.Second)},
{TS: now},
{TS: now.Add(1 * time.Second)},
}
SortConnections(conns, SortOptions{Field: SortByTimestamp, Direction: SortAsc})
if !conns[0].TS.Equal(now) {
t.Error("oldest timestamp should be first")
}
if !conns[2].TS.Equal(now.Add(2 * time.Second)) {
t.Error("newest timestamp should be last")
}
}

View File

@@ -0,0 +1,25 @@
package collector
import "time"
type Connection struct {
TS time.Time `json:"ts"`
PID int `json:"pid"`
Process string `json:"process"`
User string `json:"user"`
UID int `json:"uid"`
Proto string `json:"proto"`
IPVersion string `json:"ipversion"`
State string `json:"state"`
Laddr string `json:"laddr"`
Lport int `json:"lport"`
Raddr string `json:"raddr"`
Rport int `json:"rport"`
Interface string `json:"interface"`
RxBytes int64 `json:"rx_bytes"`
TxBytes int64 `json:"tx_bytes"`
RttMs float64 `json:"rtt_ms"`
Mark string `json:"mark"`
Namespace string `json:"namespace"`
Inode int64 `json:"inode"`
}