initial commit

This commit is contained in:
Karol Broda
2025-12-16 22:42:49 +01:00
commit 371f4d13a6
61 changed files with 6872 additions and 0 deletions

View File

@@ -0,0 +1,506 @@
package collector
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
// Collector interface defines methods for collecting connection data
type Collector interface {
GetConnections() ([]Connection, error)
}
// DefaultCollector implements the Collector interface using /proc
type DefaultCollector struct{}
// Global collector instance (can be overridden for testing)
var globalCollector Collector = &DefaultCollector{}
// SetCollector sets the global collector instance
func SetCollector(collector Collector) {
globalCollector = collector
}
// GetCollector returns the current global collector instance
func GetCollector() Collector {
return globalCollector
}
// GetConnections fetches all network connections using the global collector
func GetConnections() ([]Connection, error) {
return globalCollector.GetConnections()
}
// GetConnections fetches all network connections by parsing /proc files.
func (dc *DefaultCollector) GetConnections() ([]Connection, error) {
if runtime.GOOS != "linux" {
return nil, fmt.Errorf("proc-based collector only supports Linux")
}
// Build map of inode -> process info by scanning /proc
inodeMap, err := buildInodeToProcessMap()
if err != nil {
return nil, fmt.Errorf("failed to build inode map: %w", err)
}
var connections []Connection
// Parse TCP connections
tcpConns, err := parseProcNet("/proc/net/tcp", "tcp", 4, inodeMap)
if err == nil {
connections = append(connections, tcpConns...)
}
tcpConns6, err := parseProcNet("/proc/net/tcp6", "tcp6", 6, inodeMap)
if err == nil {
connections = append(connections, tcpConns6...)
}
// Parse UDP connections
udpConns, err := parseProcNet("/proc/net/udp", "udp", 4, inodeMap)
if err == nil {
connections = append(connections, udpConns...)
}
udpConns6, err := parseProcNet("/proc/net/udp6", "udp6", 6, inodeMap)
if err == nil {
connections = append(connections, udpConns6...)
}
return connections, nil
}
// GetAllConnections returns both network and Unix domain socket connections
func GetAllConnections() ([]Connection, error) {
// Get network connections
networkConns, err := GetConnections()
if err != nil {
return nil, err
}
// Get Unix sockets (only on Linux)
if runtime.GOOS == "linux" {
unixConns, err := GetUnixSockets()
if err == nil {
networkConns = append(networkConns, unixConns...)
}
}
return networkConns, nil
}
func FilterConnections(conns []Connection, filters FilterOptions) []Connection {
if filters.IsEmpty() {
return conns
}
filtered := make([]Connection, 0)
for _, conn := range conns {
if filters.Matches(conn) {
filtered = append(filtered, conn)
}
}
return filtered
}
// processInfo holds information about a process
type processInfo struct {
pid int
command string
uid int
user string
}
// buildInodeToProcessMap scans /proc to map socket inodes to processes
func buildInodeToProcessMap() (map[int64]*processInfo, error) {
inodeMap := make(map[int64]*processInfo)
procDir, err := os.Open("/proc")
if err != nil {
return nil, err
}
defer procDir.Close()
entries, err := procDir.Readdir(-1)
if err != nil {
return nil, err
}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
// check if directory name is a number (pid)
pidStr := entry.Name()
pid, err := strconv.Atoi(pidStr)
if err != nil {
continue
}
// get process info
procInfo, err := getProcessInfo(pid)
if err != nil {
continue
}
// scan /proc/[pid]/fd/ for socket file descriptors
fdDir := filepath.Join("/proc", pidStr, "fd")
fdEntries, err := os.ReadDir(fdDir)
if err != nil {
continue
}
for _, fdEntry := range fdEntries {
fdPath := filepath.Join(fdDir, fdEntry.Name())
link, err := os.Readlink(fdPath)
if err != nil {
continue
}
// socket inodes look like: socket:[12345]
if strings.HasPrefix(link, "socket:[") && strings.HasSuffix(link, "]") {
inodeStr := link[8 : len(link)-1]
inode, err := strconv.ParseInt(inodeStr, 10, 64)
if err != nil {
continue
}
inodeMap[inode] = procInfo
}
}
}
return inodeMap, nil
}
// getProcessInfo reads process information from /proc/[pid]/
func getProcessInfo(pid int) (*processInfo, error) {
info := &processInfo{pid: pid}
// prefer /proc/[pid]/comm as it's always just the command name
commPath := filepath.Join("/proc", strconv.Itoa(pid), "comm")
commData, err := os.ReadFile(commPath)
if err == nil && len(commData) > 0 {
info.command = strings.TrimSpace(string(commData))
}
// if comm is not available, try cmdline
if info.command == "" {
cmdlinePath := filepath.Join("/proc", strconv.Itoa(pid), "cmdline")
cmdlineData, err := os.ReadFile(cmdlinePath)
if err != nil {
return nil, err
}
// cmdline is null-separated, take first part
if len(cmdlineData) > 0 {
parts := bytes.Split(cmdlineData, []byte{0})
if len(parts) > 0 && len(parts[0]) > 0 {
fullPath := string(parts[0])
// extract basename from full path
baseName := filepath.Base(fullPath)
// if basename contains spaces (single-string cmdline), take first word
if strings.Contains(baseName, " ") {
baseName = strings.Fields(baseName)[0]
}
info.command = baseName
}
}
}
// read UID from /proc/[pid]/status
statusPath := filepath.Join("/proc", strconv.Itoa(pid), "status")
statusFile, err := os.Open(statusPath)
if err != nil {
return info, nil
}
defer statusFile.Close()
scanner := bufio.NewScanner(statusFile)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "Uid:") {
fields := strings.Fields(line)
if len(fields) >= 2 {
uid, err := strconv.Atoi(fields[1])
if err == nil {
info.uid = uid
// get username from uid
u, err := user.LookupId(strconv.Itoa(uid))
if err == nil {
info.user = u.Username
} else {
info.user = strconv.Itoa(uid)
}
}
}
break
}
}
return info, nil
}
// parseProcNet parses a /proc/net/tcp or /proc/net/udp file
func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*processInfo) ([]Connection, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
var connections []Connection
scanner := bufio.NewScanner(file)
// skip header
scanner.Scan()
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
fields := strings.Fields(line)
if len(fields) < 10 {
continue
}
// parse local address and port
localAddr, localPort, err := parseHexAddr(fields[1])
if err != nil {
continue
}
// parse remote address and port
remoteAddr, remotePort, err := parseHexAddr(fields[2])
if err != nil {
continue
}
// parse state (field 3)
stateHex := fields[3]
state := parseState(stateHex, proto)
// parse inode (field 9)
inode, _ := strconv.ParseInt(fields[9], 10, 64)
conn := Connection{
TS: time.Now(),
Proto: proto,
IPVersion: fmt.Sprintf("IPv%d", ipVersion),
State: state,
Laddr: localAddr,
Lport: localPort,
Raddr: remoteAddr,
Rport: remotePort,
Inode: inode,
}
// add process info if available
if procInfo, exists := inodeMap[inode]; exists {
conn.PID = procInfo.pid
conn.Process = procInfo.command
conn.UID = procInfo.uid
conn.User = procInfo.user
}
// determine interface
conn.Interface = guessNetworkInterface(localAddr, nil)
connections = append(connections, conn)
}
return connections, scanner.Err()
}
// parseState converts hex state value to string
func parseState(hexState, proto string) string {
state, err := strconv.ParseInt(hexState, 16, 32)
if err != nil {
return ""
}
// TCP states
tcpStates := map[int64]string{
0x01: "ESTABLISHED",
0x02: "SYN_SENT",
0x03: "SYN_RECV",
0x04: "FIN_WAIT1",
0x05: "FIN_WAIT2",
0x06: "TIME_WAIT",
0x07: "CLOSE",
0x08: "CLOSE_WAIT",
0x09: "LAST_ACK",
0x0A: "LISTEN",
0x0B: "CLOSING",
}
if strings.HasPrefix(proto, "tcp") {
if s, exists := tcpStates[state]; exists {
return s
}
} else {
// UDP doesn't have states in the same way
if state == 0x07 {
return "CLOSE"
}
return ""
}
return ""
}
// parseHexAddr parses hex-encoded address:port from /proc/net files
func parseHexAddr(hexAddr string) (string, int, error) {
parts := strings.Split(hexAddr, ":")
if len(parts) != 2 {
return "", 0, fmt.Errorf("invalid address format")
}
hexIP := parts[0]
// parse hex port
port, err := strconv.ParseInt(parts[1], 16, 32)
if err != nil {
return "", 0, err
}
if len(hexIP) == 8 {
// IPv4 (stored in little-endian)
ip1, _ := strconv.ParseInt(hexIP[6:8], 16, 32)
ip2, _ := strconv.ParseInt(hexIP[4:6], 16, 32)
ip3, _ := strconv.ParseInt(hexIP[2:4], 16, 32)
ip4, _ := strconv.ParseInt(hexIP[0:2], 16, 32)
addr := fmt.Sprintf("%d.%d.%d.%d", ip1, ip2, ip3, ip4)
// handle wildcard address
if addr == "0.0.0.0" {
addr = "*"
}
return addr, int(port), nil
} else if len(hexIP) == 32 {
// IPv6 (stored in little-endian per 32-bit word)
var ipv6Parts []string
for i := 0; i < 32; i += 8 {
word := hexIP[i : i+8]
// reverse byte order within each 32-bit word
p1 := word[6:8] + word[4:6] + word[2:4] + word[0:2]
ipv6Parts = append(ipv6Parts, p1)
}
// convert to standard IPv6 notation
fullAddr := strings.Join(ipv6Parts, "")
var formatted []string
for i := 0; i < len(fullAddr); i += 4 {
formatted = append(formatted, fullAddr[i:i+4])
}
addr := strings.Join(formatted, ":")
// simplify IPv6 address
addr = simplifyIPv6(addr)
// handle wildcard address
if addr == "::" || addr == "0:0:0:0:0:0:0:0" {
addr = "*"
}
return addr, int(port), nil
}
return "", 0, fmt.Errorf("unsupported address format")
}
// simplifyIPv6 simplifies IPv6 address notation
func simplifyIPv6(addr string) string {
// remove leading zeros from each group
parts := strings.Split(addr, ":")
for i, part := range parts {
// convert to int and back to remove leading zeros
val, err := strconv.ParseInt(part, 16, 64)
if err == nil {
parts[i] = strconv.FormatInt(val, 16)
}
}
return strings.Join(parts, ":")
}
func guessNetworkInterface(addr string, interfaces map[string]string) string {
// Simple heuristic - try to match common interface patterns
if addr == "127.0.0.1" || addr == "::1" {
return "lo"
}
// Check if it's a private network address
ip := net.ParseIP(addr)
if ip != nil {
if ip.IsLoopback() {
return "lo"
}
// More sophisticated interface detection would require routing table analysis
// For now, return a placeholder
if ip.To4() != nil {
return "eth0" // Common default for IPv4
} else {
return "eth0" // Common default for IPv6
}
}
return ""
}
// Add Unix socket support
func GetUnixSockets() ([]Connection, error) {
connections := []Connection{}
// Parse /proc/net/unix for Unix domain sockets
file, err := os.Open("/proc/net/unix")
if err != nil {
return connections, nil // silently fail on non-Linux systems
}
defer file.Close()
scanner := bufio.NewScanner(file)
// Skip header
scanner.Scan()
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
fields := strings.Fields(line)
if len(fields) < 7 {
continue
}
// Parse Unix socket information
inode, _ := strconv.ParseInt(fields[6], 10, 64)
path := ""
if len(fields) > 7 {
path = fields[7]
}
conn := Connection{
TS: time.Now(),
Proto: "unix",
Laddr: path,
Raddr: "",
State: "CONNECTED", // Simplified
Inode: inode,
Interface: "unix",
}
connections = append(connections, conn)
}
return connections, nil
}

View File

@@ -0,0 +1,16 @@
package collector
import (
"testing"
)
func TestGetConnections(t *testing.T) {
// integration test to verify /proc parsing works
conns, err := GetConnections()
if err != nil {
t.Fatalf("GetConnections() returned an error: %v", err)
}
// connections are dynamic, so just verify function succeeded
t.Logf("Successfully got %d connections", len(conns))
}

View File

@@ -0,0 +1,129 @@
package collector
import (
"strings"
"time"
)
type FilterOptions struct {
Proto string
State string
Pid int
Proc string
Lport int
Rport int
User string
UID int
Laddr string
Raddr string
Contains string
IPv4 bool
IPv6 bool
Interface string
Mark string
Namespace string
Inode int64
Since time.Time
SinceRel time.Duration
}
func (f *FilterOptions) IsEmpty() bool {
return f.Proto == "" && f.State == "" && f.Pid == 0 && f.Proc == "" &&
f.Lport == 0 && f.Rport == 0 && f.User == "" && f.UID == 0 &&
f.Laddr == "" && f.Raddr == "" && f.Contains == "" &&
f.Interface == "" && f.Mark == "" && f.Namespace == "" && f.Inode == 0 &&
f.Since.IsZero() && f.SinceRel == 0 && !f.IPv4 && !f.IPv6
}
func (f *FilterOptions) Matches(c Connection) bool {
if f.Proto != "" && !strings.EqualFold(c.Proto, f.Proto) {
return false
}
if f.State != "" && !strings.EqualFold(c.State, f.State) {
return false
}
if f.Pid != 0 && c.PID != f.Pid {
return false
}
if f.Proc != "" && !containsIgnoreCase(c.Process, f.Proc) {
return false
}
if f.Lport != 0 && c.Lport != f.Lport {
return false
}
if f.Rport != 0 && c.Rport != f.Rport {
return false
}
if f.User != "" && !strings.EqualFold(c.User, f.User) {
return false
}
if f.UID != 0 && c.UID != f.UID {
return false
}
if f.Laddr != "" && !strings.EqualFold(c.Laddr, f.Laddr) {
return false
}
if f.Raddr != "" && !strings.EqualFold(c.Raddr, f.Raddr) {
return false
}
if f.Contains != "" && !matchesContains(c, f.Contains) {
return false
}
if f.IPv4 && c.IPVersion != "IPv4" {
return false
}
if f.IPv6 && c.IPVersion != "IPv6" {
return false
}
if f.Interface != "" && !strings.EqualFold(c.Interface, f.Interface) {
return false
}
if f.Mark != "" && !strings.EqualFold(c.Mark, f.Mark) {
return false
}
if f.Namespace != "" && !strings.EqualFold(c.Namespace, f.Namespace) {
return false
}
if f.Inode != 0 && c.Inode != f.Inode {
return false
}
if !f.Since.IsZero() && c.TS.Before(f.Since) {
return false
}
if f.SinceRel != 0 {
threshold := time.Now().Add(-f.SinceRel)
if c.TS.Before(threshold) {
return false
}
}
return true
}
func containsIgnoreCase(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
func matchesContains(c Connection, query string) bool {
q := strings.ToLower(query)
return containsIgnoreCase(c.Process, q) ||
containsIgnoreCase(c.Laddr, q) ||
containsIgnoreCase(c.Raddr, q) ||
containsIgnoreCase(c.User, q)
}
// ParseTimeFilter parses a time filter string (RFC3339 or relative like "5s", "2m", "1h")
func ParseTimeFilter(timeStr string) (time.Time, time.Duration, error) {
// Try parsing as RFC3339 first
if t, err := time.Parse(time.RFC3339, timeStr); err == nil {
return t, 0, nil
}
// Try parsing as relative duration
if dur, err := time.ParseDuration(timeStr); err == nil {
return time.Time{}, dur, nil
}
return time.Time{}, 0, nil // Invalid format, but don't error
}

View File

@@ -0,0 +1,44 @@
package collector
import (
"testing"
)
func TestFilterConnections(t *testing.T) {
conns := []Connection{
{PID: 1, Process: "proc1", User: "user1", Proto: "tcp", State: "ESTABLISHED", Laddr: "1.1.1.1", Lport: 80, Raddr: "2.2.2.2", Rport: 1234},
{PID: 2, Process: "proc2", User: "user2", Proto: "udp", State: "LISTEN", Laddr: "3.3.3.3", Lport: 53, Raddr: "*", Rport: 0},
{PID: 3, Process: "proc1_extra", User: "user1", Proto: "tcp", State: "ESTABLISHED", Laddr: "4.4.4.4", Lport: 443, Raddr: "5.5.5.5", Rport: 5678},
}
testCases := []struct {
name string
filters FilterOptions
expected int
}{
{"No filters", FilterOptions{}, 3},
{"Filter by proto tcp", FilterOptions{Proto: "tcp"}, 2},
{"Filter by proto udp", FilterOptions{Proto: "udp"}, 1},
{"Filter by state", FilterOptions{State: "ESTABLISHED"}, 2},
{"Filter by pid", FilterOptions{Pid: 2}, 1},
{"Filter by proc", FilterOptions{Proc: "proc1"}, 2},
{"Filter by lport", FilterOptions{Lport: 80}, 1},
{"Filter by rport", FilterOptions{Rport: 1234}, 1},
{"Filter by user", FilterOptions{User: "user1"}, 2},
{"Filter by laddr", FilterOptions{Laddr: "1.1.1.1"}, 1},
{"Filter by raddr", FilterOptions{Raddr: "5.5.5.5"}, 1},
{"Filter by contains proc", FilterOptions{Contains: "proc2"}, 1},
{"Filter by contains addr", FilterOptions{Contains: "3.3.3.3"}, 1},
{"Combined filter", FilterOptions{Proto: "tcp", State: "ESTABLISHED"}, 2},
{"No match", FilterOptions{Proto: "tcp", State: "LISTEN"}, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
filtered := FilterConnections(conns, tc.filters)
if len(filtered) != tc.expected {
t.Errorf("Expected %d connections, but got %d", tc.expected, len(filtered))
}
})
}
}

403
internal/collector/mock.go Normal file
View File

@@ -0,0 +1,403 @@
package collector
import (
"encoding/json"
"os"
"time"
)
// MockCollector provides deterministic test data for testing
// It implements the Collector interface
type MockCollector struct {
connections []Connection
}
// NewMockCollector creates a new mock collector with default test data
func NewMockCollector() *MockCollector {
return &MockCollector{
connections: getDefaultTestConnections(),
}
}
// NewMockCollectorFromFile creates a mock collector from a JSON fixture file
func NewMockCollectorFromFile(filename string) (*MockCollector, error) {
data, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
var connections []Connection
if err := json.Unmarshal(data, &connections); err != nil {
return nil, err
}
return &MockCollector{
connections: connections,
}, nil
}
// GetConnections returns the mock connections
func (m *MockCollector) GetConnections() ([]Connection, error) {
// Return a copy to avoid mutation
result := make([]Connection, len(m.connections))
copy(result, m.connections)
return result, nil
}
// AddConnection adds a connection to the mock data
func (m *MockCollector) AddConnection(conn Connection) {
m.connections = append(m.connections, conn)
}
// SetConnections replaces all connections with the provided slice
func (m *MockCollector) SetConnections(connections []Connection) {
m.connections = make([]Connection, len(connections))
copy(m.connections, connections)
}
// SaveToFile saves the current connections to a JSON file
func (m *MockCollector) SaveToFile(filename string) error {
data, err := json.MarshalIndent(m.connections, "", " ")
if err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
// getDefaultTestConnections returns a set of deterministic test connections
func getDefaultTestConnections() []Connection {
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
return []Connection{
{
TS: baseTime,
PID: 1234,
Process: "nginx",
User: "www-data",
UID: 33,
Proto: "tcp",
IPVersion: "IPv4",
State: "LISTEN",
Laddr: "0.0.0.0",
Lport: 80,
Raddr: "*",
Rport: 0,
Interface: "eth0",
RxBytes: 0,
TxBytes: 0,
RttMs: 0,
Mark: "0x0",
Namespace: "init",
Inode: 12345,
},
{
TS: baseTime.Add(time.Second),
PID: 1234,
Process: "nginx",
User: "www-data",
UID: 33,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "10.0.0.1",
Lport: 80,
Raddr: "203.0.113.10",
Rport: 52344,
Interface: "eth0",
RxBytes: 10240,
TxBytes: 2048,
RttMs: 1.7,
Mark: "0x0",
Namespace: "init",
Inode: 12346,
},
{
TS: baseTime.Add(2 * time.Second),
PID: 5678,
Process: "postgres",
User: "postgres",
UID: 26,
Proto: "tcp",
IPVersion: "IPv4",
State: "LISTEN",
Laddr: "127.0.0.1",
Lport: 5432,
Raddr: "*",
Rport: 0,
Interface: "lo",
RxBytes: 0,
TxBytes: 0,
RttMs: 0,
Mark: "0x0",
Namespace: "init",
Inode: 12347,
},
{
TS: baseTime.Add(3 * time.Second),
PID: 5678,
Process: "postgres",
User: "postgres",
UID: 26,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "127.0.0.1",
Lport: 5432,
Raddr: "127.0.0.1",
Rport: 45678,
Interface: "lo",
RxBytes: 8192,
TxBytes: 4096,
RttMs: 0.1,
Mark: "0x0",
Namespace: "init",
Inode: 12348,
},
{
TS: baseTime.Add(4 * time.Second),
PID: 9999,
Process: "dns-server",
User: "named",
UID: 25,
Proto: "udp",
IPVersion: "IPv4",
State: "CONNECTED",
Laddr: "0.0.0.0",
Lport: 53,
Raddr: "*",
Rport: 0,
Interface: "eth0",
RxBytes: 1024,
TxBytes: 512,
RttMs: 0,
Mark: "0x0",
Namespace: "init",
Inode: 12349,
},
{
TS: baseTime.Add(5 * time.Second),
PID: 1111,
Process: "ssh",
User: "root",
UID: 0,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "192.168.1.100",
Lport: 22,
Raddr: "192.168.1.200",
Rport: 54321,
Interface: "eth0",
RxBytes: 2048,
TxBytes: 1024,
RttMs: 2.3,
Mark: "0x0",
Namespace: "init",
Inode: 12350,
},
{
TS: baseTime.Add(6 * time.Second),
PID: 2222,
Process: "app-server",
User: "app",
UID: 1000,
Proto: "unix",
IPVersion: "",
State: "CONNECTED",
Laddr: "/tmp/app.sock",
Lport: 0,
Raddr: "",
Rport: 0,
Interface: "unix",
RxBytes: 512,
TxBytes: 256,
RttMs: 0,
Mark: "",
Namespace: "init",
Inode: 12351,
},
}
}
// ConnectionBuilder provides a fluent interface for building test connections
type ConnectionBuilder struct {
conn Connection
}
// NewConnectionBuilder creates a new connection builder with sensible defaults
func NewConnectionBuilder() *ConnectionBuilder {
return &ConnectionBuilder{
conn: Connection{
TS: time.Now(),
PID: 1000,
Process: "test-process",
User: "test-user",
UID: 1000,
Proto: "tcp",
IPVersion: "IPv4",
State: "ESTABLISHED",
Laddr: "127.0.0.1",
Lport: 8080,
Raddr: "127.0.0.1",
Rport: 9090,
Interface: "lo",
RxBytes: 1024,
TxBytes: 512,
RttMs: 1.0,
Mark: "0x0",
Namespace: "init",
Inode: 99999,
},
}
}
// WithPID sets the PID
func (b *ConnectionBuilder) WithPID(pid int) *ConnectionBuilder {
b.conn.PID = pid
return b
}
// WithProcess sets the process name
func (b *ConnectionBuilder) WithProcess(process string) *ConnectionBuilder {
b.conn.Process = process
return b
}
// WithProto sets the protocol
func (b *ConnectionBuilder) WithProto(proto string) *ConnectionBuilder {
b.conn.Proto = proto
return b
}
// WithState sets the connection state
func (b *ConnectionBuilder) WithState(state string) *ConnectionBuilder {
b.conn.State = state
return b
}
// WithLocalAddr sets the local address and port
func (b *ConnectionBuilder) WithLocalAddr(addr string, port int) *ConnectionBuilder {
b.conn.Laddr = addr
b.conn.Lport = port
return b
}
// WithRemoteAddr sets the remote address and port
func (b *ConnectionBuilder) WithRemoteAddr(addr string, port int) *ConnectionBuilder {
b.conn.Raddr = addr
b.conn.Rport = port
return b
}
// WithInterface sets the network interface
func (b *ConnectionBuilder) WithInterface(iface string) *ConnectionBuilder {
b.conn.Interface = iface
return b
}
// WithBytes sets the rx and tx byte counts
func (b *ConnectionBuilder) WithBytes(rx, tx int64) *ConnectionBuilder {
b.conn.RxBytes = rx
b.conn.TxBytes = tx
return b
}
// Build returns the constructed connection
func (b *ConnectionBuilder) Build() Connection {
return b.conn
}
// TestFixture provides test scenarios for different use cases
type TestFixture struct {
Name string
Description string
Connections []Connection
}
// GetTestFixtures returns predefined test fixtures for different scenarios
func GetTestFixtures() []TestFixture {
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
return []TestFixture{
{
Name: "empty",
Description: "No connections",
Connections: []Connection{},
},
{
Name: "single-tcp",
Description: "Single TCP connection",
Connections: []Connection{
NewConnectionBuilder().
WithPID(1234).
WithProcess("test-app").
WithProto("tcp").
WithState("ESTABLISHED").
WithLocalAddr("127.0.0.1", 8080).
WithRemoteAddr("127.0.0.1", 9090).
Build(),
},
},
{
Name: "mixed-protocols",
Description: "Mix of TCP, UDP, and Unix sockets",
Connections: []Connection{
{
TS: baseTime,
PID: 1,
Process: "tcp-server",
Proto: "tcp",
State: "LISTEN",
Laddr: "0.0.0.0",
Lport: 80,
Interface: "eth0",
},
{
TS: baseTime.Add(time.Second),
PID: 2,
Process: "udp-server",
Proto: "udp",
State: "CONNECTED",
Laddr: "0.0.0.0",
Lport: 53,
Interface: "eth0",
},
{
TS: baseTime.Add(2 * time.Second),
PID: 3,
Process: "unix-app",
Proto: "unix",
State: "CONNECTED",
Laddr: "/tmp/test.sock",
Interface: "unix",
},
},
},
{
Name: "high-volume",
Description: "Large number of connections for performance testing",
Connections: generateHighVolumeConnections(1000),
},
}
}
// generateHighVolumeConnections creates a large number of test connections
func generateHighVolumeConnections(count int) []Connection {
connections := make([]Connection, count)
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
for i := 0; i < count; i++ {
connections[i] = NewConnectionBuilder().
WithPID(1000 + i).
WithProcess("worker-" + string(rune('a'+i%26))).
WithProto([]string{"tcp", "udp"}[i%2]).
WithState([]string{"ESTABLISHED", "LISTEN", "TIME_WAIT"}[i%3]).
WithLocalAddr("127.0.0.1", 8000+i%1000).
WithRemoteAddr("10.0.0."+string(rune('1'+i%10)), 9000+i%1000).
Build()
connections[i].TS = baseTime.Add(time.Duration(i) * time.Millisecond)
}
return connections
}

159
internal/collector/query.go Normal file
View File

@@ -0,0 +1,159 @@
package collector
// Query combines filtering, sorting, and limiting into a single operation
type Query struct {
Filter FilterOptions
Sort SortOptions
Limit int
}
// NewQuery creates a query with sensible defaults
func NewQuery() *Query {
return &Query{
Filter: FilterOptions{},
Sort: SortOptions{Field: SortByLport, Direction: SortAsc},
Limit: 0,
}
}
// WithFilter sets the filter options
func (q *Query) WithFilter(f FilterOptions) *Query {
q.Filter = f
return q
}
// WithSort sets the sort options
func (q *Query) WithSort(s SortOptions) *Query {
q.Sort = s
return q
}
// WithSortString parses and sets sort options from a string like "pid:desc"
func (q *Query) WithSortString(s string) *Query {
q.Sort = ParseSortOptions(s)
return q
}
// WithLimit sets the maximum number of results
func (q *Query) WithLimit(n int) *Query {
q.Limit = n
return q
}
// Proto filters by protocol
func (q *Query) Proto(proto string) *Query {
q.Filter.Proto = proto
return q
}
// State filters by connection state
func (q *Query) State(state string) *Query {
q.Filter.State = state
return q
}
// Process filters by process name (substring match)
func (q *Query) Process(proc string) *Query {
q.Filter.Proc = proc
return q
}
// PID filters by process ID
func (q *Query) PID(pid int) *Query {
q.Filter.Pid = pid
return q
}
// LocalPort filters by local port
func (q *Query) LocalPort(port int) *Query {
q.Filter.Lport = port
return q
}
// RemotePort filters by remote port
func (q *Query) RemotePort(port int) *Query {
q.Filter.Rport = port
return q
}
// IPv4Only filters to only IPv4 connections
func (q *Query) IPv4Only() *Query {
q.Filter.IPv4 = true
q.Filter.IPv6 = false
return q
}
// IPv6Only filters to only IPv6 connections
func (q *Query) IPv6Only() *Query {
q.Filter.IPv4 = false
q.Filter.IPv6 = true
return q
}
// Listening filters to only listening sockets
func (q *Query) Listening() *Query {
q.Filter.State = "LISTEN"
return q
}
// Established filters to only established connections
func (q *Query) Established() *Query {
q.Filter.State = "ESTABLISHED"
return q
}
// Contains filters by substring in process, local addr, or remote addr
func (q *Query) Contains(s string) *Query {
q.Filter.Contains = s
return q
}
// Execute runs the query and returns results
func (q *Query) Execute() ([]Connection, error) {
conns, err := GetConnections()
if err != nil {
return nil, err
}
return q.Apply(conns), nil
}
// Apply applies the query to a slice of connections
func (q *Query) Apply(conns []Connection) []Connection {
result := FilterConnections(conns, q.Filter)
SortConnections(result, q.Sort)
if q.Limit > 0 && len(result) > q.Limit {
result = result[:q.Limit]
}
return result
}
// common pre-built queries
// ListeningTCP returns a query for TCP listeners
func ListeningTCP() *Query {
return NewQuery().Proto("tcp").Listening()
}
// ListeningAll returns a query for all listeners
func ListeningAll() *Query {
return NewQuery().Listening()
}
// EstablishedTCP returns a query for established TCP connections
func EstablishedTCP() *Query {
return NewQuery().Proto("tcp").Established()
}
// ByProcess returns a query filtered by process name
func ByProcess(name string) *Query {
return NewQuery().Process(name)
}
// ByPort returns a query filtered by local port
func ByPort(port int) *Query {
return NewQuery().LocalPort(port)
}

View File

@@ -0,0 +1,165 @@
package collector
import (
"testing"
)
func TestQueryBuilder(t *testing.T) {
t.Run("fluent builder pattern", func(t *testing.T) {
q := NewQuery().
Proto("tcp").
State("LISTEN").
WithLimit(10)
if q.Filter.Proto != "tcp" {
t.Errorf("expected proto tcp, got %s", q.Filter.Proto)
}
if q.Filter.State != "LISTEN" {
t.Errorf("expected state LISTEN, got %s", q.Filter.State)
}
if q.Limit != 10 {
t.Errorf("expected limit 10, got %d", q.Limit)
}
})
t.Run("convenience methods", func(t *testing.T) {
q := NewQuery().Listening()
if q.Filter.State != "LISTEN" {
t.Errorf("Listening() should set state to LISTEN")
}
q = NewQuery().Established()
if q.Filter.State != "ESTABLISHED" {
t.Errorf("Established() should set state to ESTABLISHED")
}
q = NewQuery().IPv4Only()
if q.Filter.IPv4 != true || q.Filter.IPv6 != false {
t.Error("IPv4Only() should set IPv4=true, IPv6=false")
}
q = NewQuery().IPv6Only()
if q.Filter.IPv4 != false || q.Filter.IPv6 != true {
t.Error("IPv6Only() should set IPv4=false, IPv6=true")
}
})
t.Run("sort options", func(t *testing.T) {
q := NewQuery().WithSortString("pid:desc")
if q.Sort.Field != SortByPID {
t.Errorf("expected sort by pid, got %v", q.Sort.Field)
}
if q.Sort.Direction != SortDesc {
t.Errorf("expected sort desc, got %v", q.Sort.Direction)
}
})
}
func TestQueryApply(t *testing.T) {
conns := []Connection{
{PID: 1, Process: "nginx", Proto: "tcp", State: "LISTEN", Lport: 80},
{PID: 2, Process: "nginx", Proto: "tcp", State: "ESTABLISHED", Lport: 80},
{PID: 3, Process: "sshd", Proto: "tcp", State: "LISTEN", Lport: 22},
{PID: 4, Process: "postgres", Proto: "tcp", State: "LISTEN", Lport: 5432},
{PID: 5, Process: "dnsmasq", Proto: "udp", State: "", Lport: 53},
}
t.Run("filter by state", func(t *testing.T) {
q := NewQuery().Listening()
result := q.Apply(conns)
if len(result) != 3 {
t.Errorf("expected 3 listening connections, got %d", len(result))
}
})
t.Run("filter by proto", func(t *testing.T) {
q := NewQuery().Proto("udp")
result := q.Apply(conns)
if len(result) != 1 {
t.Errorf("expected 1 udp connection, got %d", len(result))
}
if result[0].Process != "dnsmasq" {
t.Errorf("expected dnsmasq, got %s", result[0].Process)
}
})
t.Run("filter and sort", func(t *testing.T) {
q := NewQuery().Listening().WithSortString("lport:asc")
result := q.Apply(conns)
if len(result) != 3 {
t.Fatalf("expected 3, got %d", len(result))
}
if result[0].Lport != 22 {
t.Errorf("expected port 22 first, got %d", result[0].Lport)
}
})
t.Run("filter sort and limit", func(t *testing.T) {
q := NewQuery().Proto("tcp").WithSortString("lport:asc").WithLimit(2)
result := q.Apply(conns)
if len(result) != 2 {
t.Errorf("expected 2 (limit), got %d", len(result))
}
})
t.Run("process filter substring", func(t *testing.T) {
q := NewQuery().Process("nginx")
result := q.Apply(conns)
if len(result) != 2 {
t.Errorf("expected 2 nginx connections, got %d", len(result))
}
})
t.Run("contains filter", func(t *testing.T) {
q := NewQuery().Contains("post")
result := q.Apply(conns)
if len(result) != 1 || result[0].Process != "postgres" {
t.Errorf("expected postgres, got %v", result)
}
})
}
func TestPrebuiltQueries(t *testing.T) {
t.Run("ListeningTCP", func(t *testing.T) {
q := ListeningTCP()
if q.Filter.Proto != "tcp" || q.Filter.State != "LISTEN" {
t.Error("ListeningTCP should filter tcp + LISTEN")
}
})
t.Run("ListeningAll", func(t *testing.T) {
q := ListeningAll()
if q.Filter.State != "LISTEN" {
t.Error("ListeningAll should filter LISTEN state")
}
})
t.Run("EstablishedTCP", func(t *testing.T) {
q := EstablishedTCP()
if q.Filter.Proto != "tcp" || q.Filter.State != "ESTABLISHED" {
t.Error("EstablishedTCP should filter tcp + ESTABLISHED")
}
})
t.Run("ByProcess", func(t *testing.T) {
q := ByProcess("nginx")
if q.Filter.Proc != "nginx" {
t.Error("ByProcess should set process filter")
}
})
t.Run("ByPort", func(t *testing.T) {
q := ByPort(8080)
if q.Filter.Lport != 8080 {
t.Error("ByPort should set lport filter")
}
})
}

131
internal/collector/sort.go Normal file
View File

@@ -0,0 +1,131 @@
package collector
import (
"sort"
"strings"
)
// SortField represents a field to sort by
type SortField string
const (
SortByPID SortField = "pid"
SortByProcess SortField = "process"
SortByUser SortField = "user"
SortByProto SortField = "proto"
SortByState SortField = "state"
SortByLaddr SortField = "laddr"
SortByLport SortField = "lport"
SortByRaddr SortField = "raddr"
SortByRport SortField = "rport"
SortByInterface SortField = "if"
SortByRxBytes SortField = "rx_bytes"
SortByTxBytes SortField = "tx_bytes"
SortByRttMs SortField = "rtt_ms"
SortByTimestamp SortField = "ts"
)
// SortDirection represents ascending or descending order
type SortDirection int
const (
SortAsc SortDirection = iota
SortDesc
)
// SortOptions configures how connections are sorted
type SortOptions struct {
Field SortField
Direction SortDirection
}
// ParseSortOptions parses a sort string like "pid:desc" or "lport"
func ParseSortOptions(s string) SortOptions {
if s == "" {
return SortOptions{Field: SortByLport, Direction: SortAsc}
}
parts := strings.SplitN(s, ":", 2)
field := SortField(strings.ToLower(parts[0]))
direction := SortAsc
if len(parts) > 1 && strings.ToLower(parts[1]) == "desc" {
direction = SortDesc
}
return SortOptions{Field: field, Direction: direction}
}
// SortConnections sorts a slice of connections in place
func SortConnections(conns []Connection, opts SortOptions) {
if len(conns) < 2 {
return
}
sort.SliceStable(conns, func(i, j int) bool {
less := compareConnections(conns[i], conns[j], opts.Field)
if opts.Direction == SortDesc {
return !less
}
return less
})
}
func compareConnections(a, b Connection, field SortField) bool {
switch field {
case SortByPID:
return a.PID < b.PID
case SortByProcess:
return strings.ToLower(a.Process) < strings.ToLower(b.Process)
case SortByUser:
return strings.ToLower(a.User) < strings.ToLower(b.User)
case SortByProto:
return a.Proto < b.Proto
case SortByState:
return stateOrder(a.State) < stateOrder(b.State)
case SortByLaddr:
return a.Laddr < b.Laddr
case SortByLport:
return a.Lport < b.Lport
case SortByRaddr:
return a.Raddr < b.Raddr
case SortByRport:
return a.Rport < b.Rport
case SortByInterface:
return a.Interface < b.Interface
case SortByRxBytes:
return a.RxBytes < b.RxBytes
case SortByTxBytes:
return a.TxBytes < b.TxBytes
case SortByRttMs:
return a.RttMs < b.RttMs
case SortByTimestamp:
return a.TS.Before(b.TS)
default:
return a.Lport < b.Lport
}
}
// stateOrder returns a numeric order for connection states
// puts LISTEN first, then ESTABLISHED, then others
func stateOrder(state string) int {
order := map[string]int{
"LISTEN": 0,
"ESTABLISHED": 1,
"SYN_SENT": 2,
"SYN_RECV": 3,
"FIN_WAIT1": 4,
"FIN_WAIT2": 5,
"TIME_WAIT": 6,
"CLOSE_WAIT": 7,
"LAST_ACK": 8,
"CLOSING": 9,
"CLOSED": 10,
}
if o, exists := order[strings.ToUpper(state)]; exists {
return o
}
return 99
}

View File

@@ -0,0 +1,130 @@
package collector
import (
"testing"
"time"
)
func TestSortConnections(t *testing.T) {
conns := []Connection{
{PID: 3, Process: "nginx", Lport: 80, State: "ESTABLISHED"},
{PID: 1, Process: "sshd", Lport: 22, State: "LISTEN"},
{PID: 2, Process: "postgres", Lport: 5432, State: "LISTEN"},
}
t.Run("sort by PID ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByPID, Direction: SortAsc})
if c[0].PID != 1 || c[1].PID != 2 || c[2].PID != 3 {
t.Errorf("expected PIDs [1,2,3], got [%d,%d,%d]", c[0].PID, c[1].PID, c[2].PID)
}
})
t.Run("sort by PID descending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByPID, Direction: SortDesc})
if c[0].PID != 3 || c[1].PID != 2 || c[2].PID != 1 {
t.Errorf("expected PIDs [3,2,1], got [%d,%d,%d]", c[0].PID, c[1].PID, c[2].PID)
}
})
t.Run("sort by port ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByLport, Direction: SortAsc})
if c[0].Lport != 22 || c[1].Lport != 80 || c[2].Lport != 5432 {
t.Errorf("expected ports [22,80,5432], got [%d,%d,%d]", c[0].Lport, c[1].Lport, c[2].Lport)
}
})
t.Run("sort by state puts LISTEN first", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByState, Direction: SortAsc})
if c[0].State != "LISTEN" || c[1].State != "LISTEN" || c[2].State != "ESTABLISHED" {
t.Errorf("expected LISTEN states first, got [%s,%s,%s]", c[0].State, c[1].State, c[2].State)
}
})
t.Run("sort by process case insensitive", func(t *testing.T) {
c := []Connection{
{Process: "Nginx"},
{Process: "apache"},
{Process: "SSHD"},
}
SortConnections(c, SortOptions{Field: SortByProcess, Direction: SortAsc})
if c[0].Process != "apache" {
t.Errorf("expected apache first, got %s", c[0].Process)
}
})
}
func TestParseSortOptions(t *testing.T) {
tests := []struct {
input string
wantField SortField
wantDir SortDirection
}{
{"pid", SortByPID, SortAsc},
{"pid:asc", SortByPID, SortAsc},
{"pid:desc", SortByPID, SortDesc},
{"lport", SortByLport, SortAsc},
{"LPORT:DESC", SortByLport, SortDesc},
{"", SortByLport, SortAsc}, // default
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
opts := ParseSortOptions(tt.input)
if opts.Field != tt.wantField {
t.Errorf("field: got %v, want %v", opts.Field, tt.wantField)
}
if opts.Direction != tt.wantDir {
t.Errorf("direction: got %v, want %v", opts.Direction, tt.wantDir)
}
})
}
}
func TestStateOrder(t *testing.T) {
if stateOrder("LISTEN") >= stateOrder("ESTABLISHED") {
t.Error("LISTEN should come before ESTABLISHED")
}
if stateOrder("ESTABLISHED") >= stateOrder("TIME_WAIT") {
t.Error("ESTABLISHED should come before TIME_WAIT")
}
if stateOrder("UNKNOWN") != 99 {
t.Error("unknown states should return 99")
}
}
func TestSortByTimestamp(t *testing.T) {
now := time.Now()
conns := []Connection{
{TS: now.Add(2 * time.Second)},
{TS: now},
{TS: now.Add(1 * time.Second)},
}
SortConnections(conns, SortOptions{Field: SortByTimestamp, Direction: SortAsc})
if !conns[0].TS.Equal(now) {
t.Error("oldest timestamp should be first")
}
if !conns[2].TS.Equal(now.Add(2 * time.Second)) {
t.Error("newest timestamp should be last")
}
}

View File

@@ -0,0 +1,25 @@
package collector
import "time"
type Connection struct {
TS time.Time `json:"ts"`
PID int `json:"pid"`
Process string `json:"process"`
User string `json:"user"`
UID int `json:"uid"`
Proto string `json:"proto"`
IPVersion string `json:"ipversion"`
State string `json:"state"`
Laddr string `json:"laddr"`
Lport int `json:"lport"`
Raddr string `json:"raddr"`
Rport int `json:"rport"`
Interface string `json:"interface"`
RxBytes int64 `json:"rx_bytes"`
TxBytes int64 `json:"tx_bytes"`
RttMs float64 `json:"rtt_ms"`
Mark string `json:"mark"`
Namespace string `json:"namespace"`
Inode int64 `json:"inode"`
}

60
internal/color/color.go Normal file
View File

@@ -0,0 +1,60 @@
package color
import (
"os"
"strings"
"github.com/fatih/color"
)
var (
Header = color.New(color.FgGreen, color.Bold)
Bold = color.New(color.Bold)
Faint = color.New(color.Faint)
TCP = color.New(color.FgCyan)
UDP = color.New(color.FgMagenta)
LISTEN = color.New(color.FgYellow)
ESTABLISHED = color.New(color.FgGreen)
Default = color.New(color.FgWhite)
)
func Init(mode string) {
switch mode {
case "always":
color.NoColor = false
case "never":
color.NoColor = true
case "auto":
if os.Getenv("NO_COLOR") != "" || os.Getenv("TERM") == "dumb" {
color.NoColor = true
} else {
color.NoColor = false
}
}
}
func IsColorDisabled() bool {
return color.NoColor
}
func GetProtoColor(proto string) *color.Color {
switch strings.ToLower(proto) {
case "tcp":
return TCP
case "udp":
return UDP
default:
return Default
}
}
func GetStateColor(state string) *color.Color {
switch strings.ToUpper(state) {
case "LISTEN", "LISTENING":
return LISTEN
case "ESTABLISHED":
return ESTABLISHED
default:
return Default
}
}

View File

@@ -0,0 +1,46 @@
package color
import (
"os"
"testing"
"github.com/fatih/color"
)
func TestInit(t *testing.T) {
testCases := []struct {
name string
mode string
noColor string
term string
expected bool
}{
{"Always", "always", "", "", false},
{"Never", "never", "", "", true},
{"Auto no env", "auto", "", "xterm-256color", false},
{"Auto with NO_COLOR", "auto", "1", "xterm-256color", true},
{"Auto with TERM=dumb", "auto", "", "dumb", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Save original env vars
origNoColor := os.Getenv("NO_COLOR")
origTerm := os.Getenv("TERM")
// Set test env vars
os.Setenv("NO_COLOR", tc.noColor)
os.Setenv("TERM", tc.term)
Init(tc.mode)
if color.NoColor != tc.expected {
t.Errorf("Expected color.NoColor to be %v, but got %v", tc.expected, color.NoColor)
}
// Restore original env vars
os.Setenv("NO_COLOR", origNoColor)
os.Setenv("TERM", origTerm)
})
}
}

203
internal/config/config.go Normal file
View File

@@ -0,0 +1,203 @@
package config
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/spf13/viper"
)
// Config represents the application configuration
type Config struct {
Defaults DefaultConfig `mapstructure:"defaults"`
}
// DefaultConfig contains default values for CLI options
type DefaultConfig struct {
Interval string `mapstructure:"interval"`
Numeric bool `mapstructure:"numeric"`
Fields []string `mapstructure:"fields"`
Theme string `mapstructure:"theme"`
Units string `mapstructure:"units"`
Color string `mapstructure:"color"`
Resolve bool `mapstructure:"resolve"`
IPv4 bool `mapstructure:"ipv4"`
IPv6 bool `mapstructure:"ipv6"`
NoHeaders bool `mapstructure:"no_headers"`
OutputFormat string `mapstructure:"output_format"`
SortBy string `mapstructure:"sort_by"`
}
var globalConfig *Config
// Load loads configuration from file and environment variables
func Load() (*Config, error) {
if globalConfig != nil {
return globalConfig, nil
}
v := viper.New()
// set config name and file type (auto-detect based on extension)
v.SetConfigName("snitch")
// don't set config type - let viper auto-detect based on file extension
// this allows both .toml and .yaml files to work
v.AddConfigPath("$HOME/.config/snitch")
v.AddConfigPath("$HOME/.snitch")
v.AddConfigPath("/etc/snitch")
// Environment variables
v.SetEnvPrefix("SNITCH")
v.AutomaticEnv()
// environment variable bindings for readme-documented variables
_ = v.BindEnv("config", "SNITCH_CONFIG")
_ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE")
_ = v.BindEnv("defaults.theme", "SNITCH_THEME")
_ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR")
// Set defaults
setDefaults(v)
// Handle SNITCH_CONFIG environment variable for custom config path
if configPath := os.Getenv("SNITCH_CONFIG"); configPath != "" {
v.SetConfigFile(configPath)
}
// Try to read config file
if err := v.ReadInConfig(); err != nil {
// It's OK if config file doesn't exist
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("error reading config file: %w", err)
}
}
// Handle special environment variables
handleSpecialEnvVars(v)
// Unmarshal into config struct
config := &Config{}
if err := v.Unmarshal(config); err != nil {
return nil, fmt.Errorf("error unmarshaling config: %w", err)
}
globalConfig = config
return config, nil
}
func setDefaults(v *viper.Viper) {
// Set default values matching the README specification
v.SetDefault("defaults.interval", "1s")
v.SetDefault("defaults.numeric", false)
v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"})
v.SetDefault("defaults.theme", "auto")
v.SetDefault("defaults.units", "auto")
v.SetDefault("defaults.color", "auto")
v.SetDefault("defaults.resolve", true)
v.SetDefault("defaults.ipv4", false)
v.SetDefault("defaults.ipv6", false)
v.SetDefault("defaults.no_headers", false)
v.SetDefault("defaults.output_format", "table")
v.SetDefault("defaults.sort_by", "")
}
func handleSpecialEnvVars(v *viper.Viper) {
// Handle SNITCH_NO_COLOR - if set to "1", disable color
if os.Getenv("SNITCH_NO_COLOR") == "1" {
v.Set("defaults.color", "never")
}
// Handle SNITCH_RESOLVE - if set to "0", disable resolution
if os.Getenv("SNITCH_RESOLVE") == "0" {
v.Set("defaults.resolve", false)
v.Set("defaults.numeric", true)
}
}
// Get returns the global configuration, loading it if necessary
func Get() *Config {
if globalConfig == nil {
config, err := Load()
if err != nil {
// Return default config on error
return &Config{
Defaults: DefaultConfig{
Interval: "1s",
Numeric: false,
Fields: []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"},
Theme: "auto",
Units: "auto",
Color: "auto",
Resolve: true,
IPv4: false,
IPv6: false,
NoHeaders: false,
OutputFormat: "table",
SortBy: "",
},
}
}
return config
}
return globalConfig
}
// GetInterval returns the configured interval as a duration
func (c *Config) GetInterval() time.Duration {
if duration, err := time.ParseDuration(c.Defaults.Interval); err == nil {
return duration
}
return time.Second // default fallback
}
// CreateExampleConfig creates an example configuration file
func CreateExampleConfig(path string) error {
exampleConfig := `# snitch configuration file
# See https://github.com/you/snitch for full documentation
[defaults]
# Default refresh interval for watch/stats/trace commands
interval = "1s"
# Disable name/service resolution by default
numeric = false
# Default fields to display (comma-separated list)
fields = ["pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"]
# Default theme for TUI (dark, light, mono, auto)
theme = "auto"
# Default units for byte display (auto, si, iec)
units = "auto"
# Default color mode (auto, always, never)
color = "auto"
# Enable name resolution by default
resolve = true
# Filter options
ipv4 = false
ipv6 = false
# Output options
no_headers = false
output_format = "table"
sort_by = ""
`
// Ensure directory exists
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Write config file
if err := os.WriteFile(path, []byte(exampleConfig), 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,191 @@
package resolver
import (
"context"
"net"
"strconv"
"sync"
"time"
)
// Resolver handles DNS and service name resolution with caching and timeouts
type Resolver struct {
timeout time.Duration
cache map[string]string
mutex sync.RWMutex
}
// New creates a new resolver with the specified timeout
func New(timeout time.Duration) *Resolver {
return &Resolver{
timeout: timeout,
cache: make(map[string]string),
}
}
// ResolveAddr resolves an IP address to a hostname, with caching
func (r *Resolver) ResolveAddr(addr string) string {
// Check cache first
r.mutex.RLock()
if cached, exists := r.cache[addr]; exists {
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock()
// Parse IP to validate it
ip := net.ParseIP(addr)
if ip == nil {
// Not a valid IP, return as-is
return addr
}
// Perform resolution with timeout
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
names, err := net.DefaultResolver.LookupAddr(ctx, addr)
resolved := addr // fallback to original address
if err == nil && len(names) > 0 {
resolved = names[0]
// Remove trailing dot if present
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
resolved = resolved[:len(resolved)-1]
}
}
// Cache the result
r.mutex.Lock()
r.cache[addr] = resolved
r.mutex.Unlock()
return resolved
}
// ResolvePort resolves a port number to a service name
func (r *Resolver) ResolvePort(port int, proto string) string {
if port == 0 {
return "0"
}
cacheKey := strconv.Itoa(port) + "/" + proto
// Check cache first
r.mutex.RLock()
if cached, exists := r.cache[cacheKey]; exists {
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock()
// Perform resolution with timeout
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
service, err := net.DefaultResolver.LookupPort(ctx, proto, strconv.Itoa(port))
resolved := strconv.Itoa(port) // fallback to port number
if err == nil && service != 0 {
// Try to get service name
if serviceName := getServiceName(port, proto); serviceName != "" {
resolved = serviceName
}
}
// Cache the result
r.mutex.Lock()
r.cache[cacheKey] = resolved
r.mutex.Unlock()
return resolved
}
// ResolveAddrPort resolves both address and port
func (r *Resolver) ResolveAddrPort(addr string, port int, proto string) (string, string) {
resolvedAddr := r.ResolveAddr(addr)
resolvedPort := r.ResolvePort(port, proto)
return resolvedAddr, resolvedPort
}
// ClearCache clears the resolution cache
func (r *Resolver) ClearCache() {
r.mutex.Lock()
defer r.mutex.Unlock()
r.cache = make(map[string]string)
}
// GetCacheSize returns the number of cached entries
func (r *Resolver) GetCacheSize() int {
r.mutex.RLock()
defer r.mutex.RUnlock()
return len(r.cache)
}
// getServiceName returns well-known service names for common ports
func getServiceName(port int, proto string) string {
// Common services - this could be expanded or loaded from /etc/services
services := map[string]string{
"80/tcp": "http",
"443/tcp": "https",
"22/tcp": "ssh",
"21/tcp": "ftp",
"25/tcp": "smtp",
"53/tcp": "domain",
"53/udp": "domain",
"110/tcp": "pop3",
"143/tcp": "imap",
"993/tcp": "imaps",
"995/tcp": "pop3s",
"3306/tcp": "mysql",
"5432/tcp": "postgresql",
"6379/tcp": "redis",
"3389/tcp": "rdp",
"5900/tcp": "vnc",
"23/tcp": "telnet",
"69/udp": "tftp",
"123/udp": "ntp",
"161/udp": "snmp",
"514/udp": "syslog",
"67/udp": "bootps",
"68/udp": "bootpc",
}
key := strconv.Itoa(port) + "/" + proto
if service, exists := services[key]; exists {
return service
}
return ""
}
// Global resolver instance
var globalResolver *Resolver
// SetGlobalResolver sets the global resolver instance
func SetGlobalResolver(timeout time.Duration) {
globalResolver = New(timeout)
}
// GetGlobalResolver returns the global resolver instance
func GetGlobalResolver() *Resolver {
if globalResolver == nil {
globalResolver = New(200 * time.Millisecond) // Default timeout
}
return globalResolver
}
// ResolveAddr is a convenience function using the global resolver
func ResolveAddr(addr string) string {
return GetGlobalResolver().ResolveAddr(addr)
}
// ResolvePort is a convenience function using the global resolver
func ResolvePort(port int, proto string) string {
return GetGlobalResolver().ResolvePort(port, proto)
}
// ResolveAddrPort is a convenience function using the global resolver
func ResolveAddrPort(addr string, port int, proto string) (string, string) {
return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
}

View File

@@ -0,0 +1,215 @@
package testutil
import (
"os"
"path/filepath"
"testing"
"snitch/internal/collector"
)
// TestCollector wraps MockCollector for use in tests
type TestCollector struct {
*collector.MockCollector
}
// NewTestCollector creates a new test collector with default data
func NewTestCollector() *TestCollector {
return &TestCollector{
MockCollector: collector.NewMockCollector(),
}
}
// NewTestCollectorWithFixture creates a test collector with a specific fixture
func NewTestCollectorWithFixture(fixtureName string) *TestCollector {
fixtures := collector.GetTestFixtures()
for _, fixture := range fixtures {
if fixture.Name == fixtureName {
mock := collector.NewMockCollector()
mock.SetConnections(fixture.Connections)
return &TestCollector{MockCollector: mock}
}
}
// Fallback to default if fixture not found
return NewTestCollector()
}
// SetupTestEnvironment sets up a clean test environment
func SetupTestEnvironment(t *testing.T) (string, func()) {
// Create temporary directory for test files
tempDir, err := os.MkdirTemp("", "snitch-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
// Set test environment variables
oldConfig := os.Getenv("SNITCH_CONFIG")
oldNoColor := os.Getenv("SNITCH_NO_COLOR")
os.Setenv("SNITCH_NO_COLOR", "1") // Disable colors in tests
// Cleanup function
cleanup := func() {
os.RemoveAll(tempDir)
os.Setenv("SNITCH_CONFIG", oldConfig)
os.Setenv("SNITCH_NO_COLOR", oldNoColor)
}
return tempDir, cleanup
}
// CreateFixtureFile creates a JSON fixture file with the given connections
func CreateFixtureFile(t *testing.T, dir string, name string, connections []collector.Connection) string {
mock := collector.NewMockCollector()
mock.SetConnections(connections)
filePath := filepath.Join(dir, name+".json")
if err := mock.SaveToFile(filePath); err != nil {
t.Fatalf("Failed to create fixture file %s: %v", filePath, err)
}
return filePath
}
// LoadFixtureFile loads connections from a JSON fixture file
func LoadFixtureFile(t *testing.T, filePath string) []collector.Connection {
mock, err := collector.NewMockCollectorFromFile(filePath)
if err != nil {
t.Fatalf("Failed to load fixture file %s: %v", filePath, err)
}
connections, err := mock.GetConnections()
if err != nil {
t.Fatalf("Failed to get connections from fixture: %v", err)
}
return connections
}
// AssertConnectionsEqual compares two slices of connections for equality
func AssertConnectionsEqual(t *testing.T, expected, actual []collector.Connection) {
if len(expected) != len(actual) {
t.Errorf("Connection count mismatch: expected %d, got %d", len(expected), len(actual))
return
}
for i, exp := range expected {
act := actual[i]
// Compare key fields (timestamps may vary slightly)
if exp.PID != act.PID {
t.Errorf("Connection %d PID mismatch: expected %d, got %d", i, exp.PID, act.PID)
}
if exp.Process != act.Process {
t.Errorf("Connection %d Process mismatch: expected %s, got %s", i, exp.Process, act.Process)
}
if exp.Proto != act.Proto {
t.Errorf("Connection %d Proto mismatch: expected %s, got %s", i, exp.Proto, act.Proto)
}
if exp.State != act.State {
t.Errorf("Connection %d State mismatch: expected %s, got %s", i, exp.State, act.State)
}
if exp.Laddr != act.Laddr {
t.Errorf("Connection %d Laddr mismatch: expected %s, got %s", i, exp.Laddr, act.Laddr)
}
if exp.Lport != act.Lport {
t.Errorf("Connection %d Lport mismatch: expected %d, got %d", i, exp.Lport, act.Lport)
}
}
}
// GetTestConfig returns a test configuration with safe defaults
func GetTestConfig() map[string]interface{} {
return map[string]interface{}{
"defaults": map[string]interface{}{
"interval": "1s",
"numeric": true, // Disable resolution in tests
"fields": []string{"pid", "process", "proto", "state", "laddr", "lport"},
"theme": "mono", // Use monochrome theme in tests
"units": "auto",
"color": "never",
"resolve": false,
"ipv4": false,
"ipv6": false,
"no_headers": false,
"output_format": "table",
"sort_by": "",
},
}
}
// CaptureOutput captures stdout/stderr during test execution
type OutputCapture struct {
stdout *os.File
stderr *os.File
oldStdout *os.File
oldStderr *os.File
stdoutFile string
stderrFile string
}
// NewOutputCapture creates a new output capture
func NewOutputCapture(t *testing.T) *OutputCapture {
tempDir, err := os.MkdirTemp("", "snitch-output-*")
if err != nil {
t.Fatalf("Failed to create temp dir for output capture: %v", err)
}
stdoutFile := filepath.Join(tempDir, "stdout")
stderrFile := filepath.Join(tempDir, "stderr")
stdout, err := os.Create(stdoutFile)
if err != nil {
t.Fatalf("Failed to create stdout file: %v", err)
}
stderr, err := os.Create(stderrFile)
if err != nil {
t.Fatalf("Failed to create stderr file: %v", err)
}
return &OutputCapture{
stdout: stdout,
stderr: stderr,
oldStdout: os.Stdout,
oldStderr: os.Stderr,
stdoutFile: stdoutFile,
stderrFile: stderrFile,
}
}
// Start begins capturing output
func (oc *OutputCapture) Start() {
os.Stdout = oc.stdout
os.Stderr = oc.stderr
}
// Stop stops capturing and returns the captured output
func (oc *OutputCapture) Stop() (string, string, error) {
// Restore original stdout/stderr
os.Stdout = oc.oldStdout
os.Stderr = oc.oldStderr
// Close files
oc.stdout.Close()
oc.stderr.Close()
// Read captured content
stdoutContent, err := os.ReadFile(oc.stdoutFile)
if err != nil {
return "", "", err
}
stderrContent, err := os.ReadFile(oc.stderrFile)
if err != nil {
return "", "", err
}
// Cleanup
os.Remove(oc.stdoutFile)
os.Remove(oc.stderrFile)
os.Remove(filepath.Dir(oc.stdoutFile))
return string(stdoutContent), string(stderrContent), nil
}

247
internal/theme/theme.go Normal file
View File

@@ -0,0 +1,247 @@
package theme
import (
"strings"
"github.com/charmbracelet/lipgloss"
)
// Theme represents the visual styling for the TUI
type Theme struct {
Name string
Styles Styles
}
// Styles contains all the styling definitions
type Styles struct {
Header lipgloss.Style
Border lipgloss.Style
Selected lipgloss.Style
Watched lipgloss.Style
Normal lipgloss.Style
Error lipgloss.Style
Success lipgloss.Style
Warning lipgloss.Style
Proto ProtoStyles
State StateStyles
Footer lipgloss.Style
Background lipgloss.Style
}
// ProtoStyles contains protocol-specific colors
type ProtoStyles struct {
TCP lipgloss.Style
UDP lipgloss.Style
Unix lipgloss.Style
TCP6 lipgloss.Style
UDP6 lipgloss.Style
}
// StateStyles contains connection state-specific colors
type StateStyles struct {
Listen lipgloss.Style
Established lipgloss.Style
TimeWait lipgloss.Style
CloseWait lipgloss.Style
SynSent lipgloss.Style
SynRecv lipgloss.Style
FinWait1 lipgloss.Style
FinWait2 lipgloss.Style
Closing lipgloss.Style
LastAck lipgloss.Style
Closed lipgloss.Style
}
var (
themes map[string]*Theme
)
func init() {
themes = map[string]*Theme{
"default": createAdaptiveTheme(),
"mono": createMonoTheme(),
}
}
// GetTheme returns a theme by name, with auto-detection support
func GetTheme(name string) *Theme {
if name == "auto" {
// lipgloss handles adaptive colors, so we just return the default
return themes["default"]
}
if theme, exists := themes[name]; exists {
return theme
}
// a specific theme was requested (e.g. "dark", "light"), but we now use adaptive
// so we can just return the default theme and lipgloss will handle it
if name == "dark" || name == "light" {
return themes["default"]
}
// fallback to default
return themes["default"]
}
// ListThemes returns available theme names
func ListThemes() []string {
var names []string
for name := range themes {
names = append(names, name)
}
return names
}
// createAdaptiveTheme creates a clean, minimal theme
func createAdaptiveTheme() *Theme {
return &Theme{
Name: "default",
Styles: Styles{
Header: lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "#1F2937", Dark: "#F9FAFB"}),
Watched: lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#F59E0B"}),
Border: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#D1D5DB", Dark: "#374151"}),
Selected: lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "#1F2937", Dark: "#F9FAFB"}),
Normal: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#6B7280", Dark: "#9CA3AF"}),
Error: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
Success: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
Warning: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
Footer: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#9CA3AF", Dark: "#6B7280"}),
Background: lipgloss.NewStyle(),
Proto: ProtoStyles{
TCP: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
UDP: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
Unix: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#6B7280", Dark: "#9CA3AF"}),
TCP6: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
UDP6: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
},
State: StateStyles{
Listen: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
Established: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#2563EB", Dark: "#60A5FA"}),
TimeWait: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
CloseWait: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
SynSent: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
SynRecv: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
FinWait1: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
FinWait2: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
Closing: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
LastAck: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
Closed: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#9CA3AF", Dark: "#6B7280"}),
},
},
}
}
// createMonoTheme creates a monochrome theme (no colors)
func createMonoTheme() *Theme {
baseStyle := lipgloss.NewStyle()
boldStyle := lipgloss.NewStyle().Bold(true)
return &Theme{
Name: "mono",
Styles: Styles{
Header: boldStyle,
Border: baseStyle,
Selected: boldStyle,
Normal: baseStyle,
Error: boldStyle,
Success: boldStyle,
Warning: boldStyle,
Footer: baseStyle,
Background: baseStyle,
Proto: ProtoStyles{
TCP: baseStyle,
UDP: baseStyle,
Unix: baseStyle,
TCP6: baseStyle,
UDP6: baseStyle,
},
State: StateStyles{
Listen: baseStyle,
Established: baseStyle,
TimeWait: baseStyle,
CloseWait: baseStyle,
SynSent: baseStyle,
SynRecv: baseStyle,
FinWait1: baseStyle,
FinWait2: baseStyle,
Closing: baseStyle,
LastAck: baseStyle,
Closed: baseStyle,
},
},
}
}
// GetProtoStyle returns the appropriate style for a protocol
func (s *Styles) GetProtoStyle(proto string) lipgloss.Style {
switch strings.ToLower(proto) {
case "tcp":
return s.Proto.TCP
case "udp":
return s.Proto.UDP
case "unix":
return s.Proto.Unix
case "tcp6":
return s.Proto.TCP6
case "udp6":
return s.Proto.UDP6
default:
return s.Normal
}
}
// GetStateStyle returns the appropriate style for a connection state
func (s *Styles) GetStateStyle(state string) lipgloss.Style {
switch strings.ToUpper(state) {
case "LISTEN":
return s.State.Listen
case "ESTABLISHED":
return s.State.Established
case "TIME_WAIT":
return s.State.TimeWait
case "CLOSE_WAIT":
return s.State.CloseWait
case "SYN_SENT":
return s.State.SynSent
case "SYN_RECV":
return s.State.SynRecv
case "FIN_WAIT1":
return s.State.FinWait1
case "FIN_WAIT2":
return s.State.FinWait2
case "CLOSING":
return s.State.Closing
case "LAST_ACK":
return s.State.LastAck
case "CLOSED":
return s.State.Closed
default:
return s.Normal
}
}

53
internal/tui/helpers.go Normal file
View File

@@ -0,0 +1,53 @@
package tui
import (
"fmt"
"regexp"
"snitch/internal/collector"
"strings"
)
func truncate(s string, max int) string {
if len(s) <= max {
return s
}
if max <= 2 {
return s[:max]
}
return s[:max-1] + "…"
}
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
func stripAnsi(s string) string {
return ansiRegex.ReplaceAllString(s, "")
}
func containsIgnoreCase(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
func sortFieldLabel(f collector.SortField) string {
switch f {
case collector.SortByLport:
return "port"
case collector.SortByProcess:
return "proc"
case collector.SortByPID:
return "pid"
case collector.SortByState:
return "state"
case collector.SortByProto:
return "proto"
default:
return "port"
}
}
func formatRemote(addr string, port int) string {
if addr == "" || addr == "*" || port == 0 {
return "-"
}
return fmt.Sprintf("%s:%d", addr, port)
}

185
internal/tui/keys.go Normal file
View File

@@ -0,0 +1,185 @@
package tui
import (
"snitch/internal/collector"
tea "github.com/charmbracelet/bubbletea"
)
func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
// search mode captures all input
if m.searchActive {
return m.handleSearchKey(msg)
}
// detail view only allows closing
if m.showDetail {
return m.handleDetailKey(msg)
}
// help view only allows closing
if m.showHelp {
return m.handleHelpKey(msg)
}
return m.handleNormalKey(msg)
}
func (m model) handleSearchKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "esc":
m.searchActive = false
m.searchQuery = ""
case "enter":
m.searchActive = false
m.cursor = 0
case "backspace":
if len(m.searchQuery) > 0 {
m.searchQuery = m.searchQuery[:len(m.searchQuery)-1]
}
default:
if len(msg.String()) == 1 {
m.searchQuery += msg.String()
}
}
return m, nil
}
func (m model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "esc", "enter", "q":
m.showDetail = false
m.selected = nil
}
return m, nil
}
func (m model) handleHelpKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "esc", "enter", "q", "?":
m.showHelp = false
}
return m, nil
}
func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "q", "ctrl+c":
return m, tea.Sequence(tea.ShowCursor, tea.Quit)
// navigation
case "j", "down":
m.moveCursor(1)
case "k", "up":
m.moveCursor(-1)
case "g":
m.cursor = 0
case "G":
visible := m.visibleConnections()
if len(visible) > 0 {
m.cursor = len(visible) - 1
}
case "ctrl+d":
m.moveCursor(m.pageSize() / 2)
case "ctrl+u":
m.moveCursor(-m.pageSize() / 2)
case "ctrl+f", "pgdown":
m.moveCursor(m.pageSize())
case "ctrl+b", "pgup":
m.moveCursor(-m.pageSize())
// filter toggles
case "t":
m.showTCP = !m.showTCP
m.clampCursor()
case "u":
m.showUDP = !m.showUDP
m.clampCursor()
case "l":
m.showListening = !m.showListening
m.clampCursor()
case "e":
m.showEstablished = !m.showEstablished
m.clampCursor()
case "o":
m.showOther = !m.showOther
m.clampCursor()
case "a":
m.showTCP = true
m.showUDP = true
m.showListening = true
m.showEstablished = true
m.showOther = true
// sorting
case "s":
m.cycleSort()
case "S":
m.sortReverse = !m.sortReverse
m.applySorting()
// search
case "/":
m.searchActive = true
m.searchQuery = ""
// actions
case "enter", " ":
visible := m.visibleConnections()
if m.cursor < len(visible) {
conn := visible[m.cursor]
m.selected = &conn
m.showDetail = true
}
case "r":
return m, m.fetchData()
case "?":
m.showHelp = true
}
return m, nil
}
func (m *model) moveCursor(delta int) {
visible := m.visibleConnections()
m.cursor += delta
if m.cursor < 0 {
m.cursor = 0
}
if m.cursor >= len(visible) {
m.cursor = len(visible) - 1
}
if m.cursor < 0 {
m.cursor = 0
}
}
func (m model) pageSize() int {
size := m.height - 6
if size < 1 {
return 10
}
return size
}
func (m *model) cycleSort() {
fields := []collector.SortField{
collector.SortByLport,
collector.SortByProcess,
collector.SortByPID,
collector.SortByState,
collector.SortByProto,
}
for i, f := range fields {
if f == m.sortField {
m.sortField = fields[(i+1)%len(fields)]
m.applySorting()
return
}
}
m.sortField = collector.SortByLport
m.applySorting()
}

35
internal/tui/messages.go Normal file
View File

@@ -0,0 +1,35 @@
package tui
import (
"snitch/internal/collector"
"time"
tea "github.com/charmbracelet/bubbletea"
)
type tickMsg time.Time
type dataMsg struct {
connections []collector.Connection
}
type errMsg struct {
err error
}
func (m model) tick() tea.Cmd {
return tea.Tick(m.interval, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m model) fetchData() tea.Cmd {
return func() tea.Msg {
conns, err := collector.GetConnections()
if err != nil {
return errMsg{err}
}
return dataMsg{connections: conns}
}
}

220
internal/tui/model.go Normal file
View File

@@ -0,0 +1,220 @@
package tui
import (
"snitch/internal/collector"
"snitch/internal/theme"
"time"
tea "github.com/charmbracelet/bubbletea"
)
type model struct {
connections []collector.Connection
cursor int
width int
height int
// filtering
showTCP bool
showUDP bool
showListening bool
showEstablished bool
showOther bool
searchQuery string
searchActive bool
// sorting
sortField collector.SortField
sortReverse bool
// ui state
theme *theme.Theme
showHelp bool
showDetail bool
selected *collector.Connection
interval time.Duration
lastRefresh time.Time
err error
}
type Options struct {
Theme string
Interval time.Duration
TCP bool
UDP bool
Listening bool
Established bool
Other bool
FilterSet bool // true if user specified any filter flags
}
func New(opts Options) model {
interval := opts.Interval
if interval == 0 {
interval = time.Second
}
// default: show everything
showTCP := true
showUDP := true
showListening := true
showEstablished := true
showOther := true
// if user specified filters, use those instead
if opts.FilterSet {
showTCP = opts.TCP
showUDP = opts.UDP
showListening = opts.Listening
showEstablished = opts.Established
showOther = opts.Other
// if only proto filters set, show all states
if !opts.Listening && !opts.Established && !opts.Other {
showListening = true
showEstablished = true
showOther = true
}
// if only state filters set, show all protos
if !opts.TCP && !opts.UDP {
showTCP = true
showUDP = true
}
}
return model{
connections: []collector.Connection{},
showTCP: showTCP,
showUDP: showUDP,
showListening: showListening,
showEstablished: showEstablished,
showOther: showOther,
sortField: collector.SortByLport,
theme: theme.GetTheme(opts.Theme),
interval: interval,
lastRefresh: time.Now(),
}
}
func (m model) Init() tea.Cmd {
return tea.Batch(
tea.HideCursor,
m.fetchData(),
m.tick(),
)
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
return m, nil
case tea.KeyMsg:
return m.handleKey(msg)
case tickMsg:
return m, tea.Batch(m.fetchData(), m.tick())
case dataMsg:
m.connections = msg.connections
m.lastRefresh = time.Now()
m.applySorting()
m.clampCursor()
return m, nil
case errMsg:
m.err = msg.err
return m, nil
}
return m, nil
}
func (m model) View() string {
if m.err != nil {
return m.renderError()
}
if m.showHelp {
return m.renderHelp()
}
if m.showDetail && m.selected != nil {
return m.renderDetail()
}
return m.renderMain()
}
func (m *model) applySorting() {
direction := collector.SortAsc
if m.sortReverse {
direction = collector.SortDesc
}
collector.SortConnections(m.connections, collector.SortOptions{
Field: m.sortField,
Direction: direction,
})
}
func (m *model) clampCursor() {
visible := m.visibleConnections()
if m.cursor >= len(visible) {
m.cursor = len(visible) - 1
}
if m.cursor < 0 {
m.cursor = 0
}
}
func (m model) visibleConnections() []collector.Connection {
var result []collector.Connection
for _, c := range m.connections {
if !m.matchesFilters(c) {
continue
}
if m.searchQuery != "" && !m.matchesSearch(c) {
continue
}
result = append(result, c)
}
return result
}
func (m model) matchesFilters(c collector.Connection) bool {
isTCP := c.Proto == "tcp" || c.Proto == "tcp6"
isUDP := c.Proto == "udp" || c.Proto == "udp6"
if isTCP && !m.showTCP {
return false
}
if isUDP && !m.showUDP {
return false
}
isListening := c.State == "LISTEN"
isEstablished := c.State == "ESTABLISHED"
isOther := !isListening && !isEstablished
if isListening && !m.showListening {
return false
}
if isEstablished && !m.showEstablished {
return false
}
if isOther && !m.showOther {
return false
}
return true
}
func (m model) matchesSearch(c collector.Connection) bool {
return containsIgnoreCase(c.Process, m.searchQuery) ||
containsIgnoreCase(c.Laddr, m.searchQuery) ||
containsIgnoreCase(c.Raddr, m.searchQuery) ||
containsIgnoreCase(c.User, m.searchQuery) ||
containsIgnoreCase(c.Proto, m.searchQuery) ||
containsIgnoreCase(c.State, m.searchQuery)
}

352
internal/tui/view.go Normal file
View File

@@ -0,0 +1,352 @@
package tui
import (
"fmt"
"snitch/internal/collector"
"strings"
"time"
)
func (m model) renderMain() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(m.renderTitle())
b.WriteString("\n")
b.WriteString(m.renderFilters())
b.WriteString("\n\n")
b.WriteString(m.renderTableHeader())
b.WriteString(m.renderSeparator())
b.WriteString(m.renderConnections())
b.WriteString("\n")
b.WriteString(m.renderStatusLine())
return b.String()
}
func (m model) renderTitle() string {
visible := m.visibleConnections()
total := len(m.connections)
left := m.theme.Styles.Header.Render("snitch")
ago := time.Since(m.lastRefresh).Round(time.Millisecond * 100)
right := m.theme.Styles.Normal.Render(fmt.Sprintf("%d/%d connections ↻ %s", len(visible), total, formatDuration(ago)))
w := m.safeWidth()
gap := w - len(stripAnsi(left)) - len(stripAnsi(right)) - 2
if gap < 0 {
gap = 0
}
return " " + left + strings.Repeat(" ", gap) + right
}
func (m model) renderFilters() string {
var parts []string
if m.showTCP {
parts = append(parts, m.theme.Styles.Success.Render("tcp"))
} else {
parts = append(parts, m.theme.Styles.Normal.Render("tcp"))
}
if m.showUDP {
parts = append(parts, m.theme.Styles.Success.Render("udp"))
} else {
parts = append(parts, m.theme.Styles.Normal.Render("udp"))
}
parts = append(parts, m.theme.Styles.Border.Render("│"))
if m.showListening {
parts = append(parts, m.theme.Styles.Success.Render("listen"))
} else {
parts = append(parts, m.theme.Styles.Normal.Render("listen"))
}
if m.showEstablished {
parts = append(parts, m.theme.Styles.Success.Render("estab"))
} else {
parts = append(parts, m.theme.Styles.Normal.Render("estab"))
}
if m.showOther {
parts = append(parts, m.theme.Styles.Success.Render("other"))
} else {
parts = append(parts, m.theme.Styles.Normal.Render("other"))
}
left := " " + strings.Join(parts, " ")
sortLabel := sortFieldLabel(m.sortField)
sortDir := "↑"
if m.sortReverse {
sortDir = "↓"
}
var right string
if m.searchActive {
right = m.theme.Styles.Warning.Render(fmt.Sprintf("/%s▌", m.searchQuery))
} else if m.searchQuery != "" {
right = m.theme.Styles.Normal.Render(fmt.Sprintf("filter: %s", m.searchQuery))
} else {
right = m.theme.Styles.Normal.Render(fmt.Sprintf("sort: %s %s", sortLabel, sortDir))
}
w := m.safeWidth()
gap := w - len(stripAnsi(left)) - len(stripAnsi(right)) - 2
if gap < 0 {
gap = 0
}
return left + strings.Repeat(" ", gap) + right + " "
}
func (m model) renderTableHeader() string {
cols := m.columnWidths()
header := fmt.Sprintf(" %-*s %-*s %-*s %-*s %-*s %s",
cols.process, "PROCESS",
cols.port, "PORT",
cols.proto, "PROTO",
cols.state, "STATE",
cols.local, "LOCAL",
"REMOTE")
return m.theme.Styles.Header.Render(header) + "\n"
}
func (m model) renderSeparator() string {
w := m.width - 4
if w < 1 {
w = 76
}
line := " " + strings.Repeat("─", w)
return m.theme.Styles.Border.Render(line) + "\n"
}
func (m model) renderConnections() string {
var b strings.Builder
visible := m.visibleConnections()
pageSize := m.pageSize()
if len(visible) == 0 {
empty := "\n " + m.theme.Styles.Normal.Render("no connections match filters") + "\n"
return empty
}
start := m.scrollOffset(pageSize, len(visible))
for i := 0; i < pageSize; i++ {
idx := start + i
if idx >= len(visible) {
b.WriteString("\n")
continue
}
isSelected := idx == m.cursor
b.WriteString(m.renderRow(visible[idx], isSelected))
}
return b.String()
}
func (m model) renderRow(c collector.Connection, selected bool) string {
cols := m.columnWidths()
indicator := " "
if selected {
indicator = m.theme.Styles.Success.Render("▸ ")
}
process := truncate(c.Process, cols.process)
if process == "" {
process = ""
}
port := fmt.Sprintf("%d", c.Lport)
proto := c.Proto
state := c.State
if state == "" {
state = ""
}
local := c.Laddr
if local == "*" || local == "" {
local = "*"
}
remote := formatRemote(c.Raddr, c.Rport)
// apply styling
protoStyled := m.theme.Styles.GetProtoStyle(proto).Render(fmt.Sprintf("%-*s", cols.proto, proto))
stateStyled := m.theme.Styles.GetStateStyle(state).Render(fmt.Sprintf("%-*s", cols.state, truncate(state, cols.state)))
row := fmt.Sprintf("%s%-*s %-*s %s %s %-*s %s",
indicator,
cols.process, process,
cols.port, port,
protoStyled,
stateStyled,
cols.local, truncate(local, cols.local),
truncate(remote, cols.remote))
if selected {
return m.theme.Styles.Selected.Render(row) + "\n"
}
return m.theme.Styles.Normal.Render(row) + "\n"
}
func (m model) renderStatusLine() string {
left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state s sort / search ? help q quit")
return left
}
func (m model) renderError() string {
return fmt.Sprintf("\n %s\n\n press q to quit\n",
m.theme.Styles.Error.Render(fmt.Sprintf("error: %v", m.err)))
}
func (m model) renderHelp() string {
help := `
navigation
──────────
j/k ↑/↓ move cursor
g/G jump to top/bottom
ctrl+d/u half page down/up
enter show connection details
filters
───────
t toggle tcp
u toggle udp
l toggle listening
e toggle established
o toggle other states
a reset all filters
sorting
───────
s cycle sort field
S reverse sort order
other
─────
/ search
r refresh now
q quit
press ? or esc to close
`
return m.theme.Styles.Normal.Render(help)
}
func (m model) renderDetail() string {
if m.selected == nil {
return ""
}
c := m.selected
var b strings.Builder
b.WriteString("\n")
b.WriteString(" " + m.theme.Styles.Header.Render("connection details") + "\n")
b.WriteString(" " + m.theme.Styles.Border.Render(strings.Repeat("─", 40)) + "\n\n")
fields := []struct {
label string
value string
}{
{"process", c.Process},
{"pid", fmt.Sprintf("%d", c.PID)},
{"user", c.User},
{"protocol", c.Proto},
{"state", c.State},
{"local", fmt.Sprintf("%s:%d", c.Laddr, c.Lport)},
{"remote", fmt.Sprintf("%s:%d", c.Raddr, c.Rport)},
{"interface", c.Interface},
{"inode", fmt.Sprintf("%d", c.Inode)},
}
for _, f := range fields {
val := f.value
if val == "" || val == "0" || val == ":0" {
val = ""
}
line := fmt.Sprintf(" %-12s %s\n", m.theme.Styles.Header.Render(f.label), val)
b.WriteString(line)
}
b.WriteString("\n")
b.WriteString(" " + m.theme.Styles.Normal.Render("press esc to close") + "\n")
return b.String()
}
func (m model) scrollOffset(pageSize, total int) int {
if total <= pageSize {
return 0
}
// keep cursor roughly centered
offset := m.cursor - pageSize/2
if offset < 0 {
offset = 0
}
if offset > total-pageSize {
offset = total - pageSize
}
return offset
}
type columns struct {
process int
port int
proto int
state int
local int
remote int
}
func (m model) columnWidths() columns {
available := m.safeWidth() - 16
c := columns{
process: 16,
port: 6,
proto: 5,
state: 11,
local: 15,
remote: 20,
}
used := c.process + c.port + c.proto + c.state + c.local + c.remote
extra := available - used
if extra > 0 {
c.process += extra / 3
c.remote += extra - extra/3
}
return c
}
func (m model) safeWidth() int {
if m.width < 80 {
return 80
}
return m.width
}
func formatDuration(d time.Duration) string {
if d < time.Second {
return fmt.Sprintf("%dms", d.Milliseconds())
}
if d < time.Minute {
return fmt.Sprintf("%.1fs", d.Seconds())
}
return fmt.Sprintf("%.0fm", d.Minutes())
}