feat: introduce theme management and performance improvements (#7)

This commit is contained in:
Karol Broda
2025-12-24 10:49:03 +01:00
committed by Karol Broda
parent 9b09ee7c6c
commit 933da8bad2
20 changed files with 877 additions and 172 deletions

View File

@@ -11,21 +11,76 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// set SNITCH_DEBUG_TIMING=1 to enable timing diagnostics
var debugTiming = os.Getenv("SNITCH_DEBUG_TIMING") != ""
func logTiming(label string, start time.Time, extra ...string) {
if !debugTiming {
return
}
elapsed := time.Since(start)
if len(extra) > 0 {
fmt.Fprintf(os.Stderr, "[timing] %s: %v (%s)\n", label, elapsed, extra[0])
} else {
fmt.Fprintf(os.Stderr, "[timing] %s: %v\n", label, elapsed)
}
}
// userCache caches uid to username mappings to avoid repeated lookups
var userCache = struct {
sync.RWMutex
m map[int]string
}{m: make(map[int]string)}
func lookupUsername(uid int) string {
userCache.RLock()
if username, exists := userCache.m[uid]; exists {
userCache.RUnlock()
return username
}
userCache.RUnlock()
start := time.Now()
username := strconv.Itoa(uid)
u, err := user.LookupId(strconv.Itoa(uid))
if err == nil && u != nil {
username = u.Username
}
elapsed := time.Since(start)
if debugTiming && elapsed > 10*time.Millisecond {
fmt.Fprintf(os.Stderr, "[timing] user.LookupId(%d) slow: %v\n", uid, elapsed)
}
userCache.Lock()
userCache.m[uid] = username
userCache.Unlock()
return username
}
// DefaultCollector implements the Collector interface using /proc filesystem
type DefaultCollector struct{}
// GetConnections fetches all network connections by parsing /proc files
func (dc *DefaultCollector) GetConnections() ([]Connection, error) {
totalStart := time.Now()
defer func() { logTiming("GetConnections total", totalStart) }()
inodeStart := time.Now()
inodeMap, err := buildInodeToProcessMap()
logTiming("buildInodeToProcessMap", inodeStart, fmt.Sprintf("%d inodes", len(inodeMap)))
if err != nil {
return nil, fmt.Errorf("failed to build inode map: %w", err)
}
var connections []Connection
parseStart := time.Now()
tcpConns, err := parseProcNet("/proc/net/tcp", "tcp", 4, inodeMap)
if err == nil {
connections = append(connections, tcpConns...)
@@ -45,6 +100,7 @@ func (dc *DefaultCollector) GetConnections() ([]Connection, error) {
if err == nil {
connections = append(connections, udpConns6...)
}
logTiming("parseProcNet (all)", parseStart, fmt.Sprintf("%d connections", len(connections)))
return connections, nil
}
@@ -71,9 +127,13 @@ type processInfo struct {
user string
}
func buildInodeToProcessMap() (map[int64]*processInfo, error) {
inodeMap := make(map[int64]*processInfo)
type inodeEntry struct {
inode int64
info *processInfo
}
func buildInodeToProcessMap() (map[int64]*processInfo, error) {
readDirStart := time.Now()
procDir, err := os.Open("/proc")
if err != nil {
return nil, err
@@ -85,47 +145,103 @@ func buildInodeToProcessMap() (map[int64]*processInfo, error) {
return nil, err
}
// collect pids first
pids := make([]int, 0, len(entries))
for _, entry := range entries {
if !entry.IsDir() {
continue
}
pid, err := strconv.Atoi(entry.Name())
if err != nil {
continue
}
pids = append(pids, pid)
}
logTiming(" readdir /proc", readDirStart, fmt.Sprintf("%d pids", len(pids)))
pidStr := entry.Name()
pid, err := strconv.Atoi(pidStr)
// process pids in parallel with limited concurrency
scanStart := time.Now()
const numWorkers = 8
pidChan := make(chan int, len(pids))
resultChan := make(chan []inodeEntry, len(pids))
var totalFDs atomic.Int64
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for pid := range pidChan {
entries := scanProcessSockets(pid)
if len(entries) > 0 {
totalFDs.Add(int64(len(entries)))
resultChan <- entries
}
}
}()
}
for _, pid := range pids {
pidChan <- pid
}
close(pidChan)
go func() {
wg.Wait()
close(resultChan)
}()
inodeMap := make(map[int64]*processInfo)
for entries := range resultChan {
for _, e := range entries {
inodeMap[e.inode] = e.info
}
}
logTiming(" scan all processes", scanStart, fmt.Sprintf("%d socket fds scanned", totalFDs.Load()))
return inodeMap, nil
}
func scanProcessSockets(pid int) []inodeEntry {
start := time.Now()
procInfo, err := getProcessInfo(pid)
if err != nil {
return nil
}
pidStr := strconv.Itoa(pid)
fdDir := filepath.Join("/proc", pidStr, "fd")
fdEntries, err := os.ReadDir(fdDir)
if err != nil {
return nil
}
var results []inodeEntry
for _, fdEntry := range fdEntries {
fdPath := filepath.Join(fdDir, fdEntry.Name())
link, err := os.Readlink(fdPath)
if err != nil {
continue
}
procInfo, err := getProcessInfo(pid)
if err != nil {
continue
}
fdDir := filepath.Join("/proc", pidStr, "fd")
fdEntries, err := os.ReadDir(fdDir)
if err != nil {
continue
}
for _, fdEntry := range fdEntries {
fdPath := filepath.Join(fdDir, fdEntry.Name())
link, err := os.Readlink(fdPath)
if strings.HasPrefix(link, "socket:[") && strings.HasSuffix(link, "]") {
inodeStr := link[8 : len(link)-1]
inode, err := strconv.ParseInt(inodeStr, 10, 64)
if err != nil {
continue
}
if strings.HasPrefix(link, "socket:[") && strings.HasSuffix(link, "]") {
inodeStr := link[8 : len(link)-1]
inode, err := strconv.ParseInt(inodeStr, 10, 64)
if err != nil {
continue
}
inodeMap[inode] = procInfo
}
results = append(results, inodeEntry{inode: inode, info: procInfo})
}
}
return inodeMap, nil
elapsed := time.Since(start)
if debugTiming && elapsed > 20*time.Millisecond {
fmt.Fprintf(os.Stderr, "[timing] slow process scan: pid=%d (%s) fds=%d time=%v\n",
pid, procInfo.command, len(fdEntries), elapsed)
}
return results
}
func getProcessInfo(pid int) (*processInfo, error) {
@@ -173,12 +289,7 @@ func getProcessInfo(pid int) (*processInfo, error) {
uid, err := strconv.Atoi(fields[1])
if err == nil {
info.uid = uid
u, err := user.LookupId(strconv.Itoa(uid))
if err == nil {
info.user = u.Username
} else {
info.user = strconv.Itoa(uid)
}
info.user = lookupUsername(uid)
}
}
break

View File

@@ -1,7 +1,10 @@
//go:build linux
package collector
import (
"testing"
"time"
)
func TestGetConnections(t *testing.T) {
@@ -13,4 +16,102 @@ func TestGetConnections(t *testing.T) {
// connections are dynamic, so just verify function succeeded
t.Logf("Successfully got %d connections", len(conns))
}
func TestGetConnectionsPerformance(t *testing.T) {
// measures performance to catch regressions
// run with: go test -v -run TestGetConnectionsPerformance
const maxDuration = 500 * time.Millisecond
const iterations = 5
// warm up caches first
_, err := GetConnections()
if err != nil {
t.Fatalf("warmup failed: %v", err)
}
var total time.Duration
var maxSeen time.Duration
for i := 0; i < iterations; i++ {
start := time.Now()
conns, err := GetConnections()
elapsed := time.Since(start)
if err != nil {
t.Fatalf("iteration %d failed: %v", i, err)
}
total += elapsed
if elapsed > maxSeen {
maxSeen = elapsed
}
t.Logf("iteration %d: %v (%d connections)", i+1, elapsed, len(conns))
}
avg := total / time.Duration(iterations)
t.Logf("average: %v, max: %v", avg, maxSeen)
if maxSeen > maxDuration {
t.Errorf("slowest iteration took %v, expected < %v", maxSeen, maxDuration)
}
}
func TestGetConnectionsColdCache(t *testing.T) {
// tests performance with cold user cache
// this simulates first run or after cache invalidation
const maxDuration = 2 * time.Second
clearUserCache()
start := time.Now()
conns, err := GetConnections()
elapsed := time.Since(start)
if err != nil {
t.Fatalf("GetConnections() failed: %v", err)
}
t.Logf("cold cache: %v (%d connections, %d cached users after)",
elapsed, len(conns), userCacheSize())
if elapsed > maxDuration {
t.Errorf("cold cache took %v, expected < %v", elapsed, maxDuration)
}
}
func BenchmarkGetConnections(b *testing.B) {
// warm cache benchmark - measures typical runtime
// run with: go test -bench=BenchmarkGetConnections -benchtime=5s
// warm up
_, _ = GetConnections()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = GetConnections()
}
}
func BenchmarkGetConnectionsColdCache(b *testing.B) {
// cold cache benchmark - measures worst-case with cache cleared each iteration
// run with: go test -bench=BenchmarkGetConnectionsColdCache -benchtime=10s
b.ResetTimer()
for i := 0; i < b.N; i++ {
clearUserCache()
_, _ = GetConnections()
}
}
func BenchmarkBuildInodeMap(b *testing.B) {
// benchmarks just the inode map building (most expensive part)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = buildInodeToProcessMap()
}
}

View File

@@ -0,0 +1,18 @@
//go:build linux
package collector
// clearUserCache clears the user lookup cache for testing
func clearUserCache() {
userCache.Lock()
userCache.m = make(map[int]string)
userCache.Unlock()
}
// userCacheSize returns the number of cached user entries
func userCacheSize() int {
userCache.RLock()
defer userCache.RUnlock()
return len(userCache.m)
}