initial commit
This commit is contained in:
506
internal/collector/collector.go
Normal file
506
internal/collector/collector.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Collector interface defines methods for collecting connection data
|
||||
type Collector interface {
|
||||
GetConnections() ([]Connection, error)
|
||||
}
|
||||
|
||||
// DefaultCollector implements the Collector interface using /proc
|
||||
type DefaultCollector struct{}
|
||||
|
||||
// Global collector instance (can be overridden for testing)
|
||||
var globalCollector Collector = &DefaultCollector{}
|
||||
|
||||
// SetCollector sets the global collector instance
|
||||
func SetCollector(collector Collector) {
|
||||
globalCollector = collector
|
||||
}
|
||||
|
||||
// GetCollector returns the current global collector instance
|
||||
func GetCollector() Collector {
|
||||
return globalCollector
|
||||
}
|
||||
|
||||
// GetConnections fetches all network connections using the global collector
|
||||
func GetConnections() ([]Connection, error) {
|
||||
return globalCollector.GetConnections()
|
||||
}
|
||||
|
||||
// GetConnections fetches all network connections by parsing /proc files.
|
||||
func (dc *DefaultCollector) GetConnections() ([]Connection, error) {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil, fmt.Errorf("proc-based collector only supports Linux")
|
||||
}
|
||||
|
||||
// Build map of inode -> process info by scanning /proc
|
||||
inodeMap, err := buildInodeToProcessMap()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build inode map: %w", err)
|
||||
}
|
||||
|
||||
var connections []Connection
|
||||
|
||||
// Parse TCP connections
|
||||
tcpConns, err := parseProcNet("/proc/net/tcp", "tcp", 4, inodeMap)
|
||||
if err == nil {
|
||||
connections = append(connections, tcpConns...)
|
||||
}
|
||||
|
||||
tcpConns6, err := parseProcNet("/proc/net/tcp6", "tcp6", 6, inodeMap)
|
||||
if err == nil {
|
||||
connections = append(connections, tcpConns6...)
|
||||
}
|
||||
|
||||
// Parse UDP connections
|
||||
udpConns, err := parseProcNet("/proc/net/udp", "udp", 4, inodeMap)
|
||||
if err == nil {
|
||||
connections = append(connections, udpConns...)
|
||||
}
|
||||
|
||||
udpConns6, err := parseProcNet("/proc/net/udp6", "udp6", 6, inodeMap)
|
||||
if err == nil {
|
||||
connections = append(connections, udpConns6...)
|
||||
}
|
||||
|
||||
return connections, nil
|
||||
}
|
||||
|
||||
// GetAllConnections returns both network and Unix domain socket connections
|
||||
func GetAllConnections() ([]Connection, error) {
|
||||
// Get network connections
|
||||
networkConns, err := GetConnections()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get Unix sockets (only on Linux)
|
||||
if runtime.GOOS == "linux" {
|
||||
unixConns, err := GetUnixSockets()
|
||||
if err == nil {
|
||||
networkConns = append(networkConns, unixConns...)
|
||||
}
|
||||
}
|
||||
|
||||
return networkConns, nil
|
||||
}
|
||||
|
||||
func FilterConnections(conns []Connection, filters FilterOptions) []Connection {
|
||||
if filters.IsEmpty() {
|
||||
return conns
|
||||
}
|
||||
|
||||
filtered := make([]Connection, 0)
|
||||
for _, conn := range conns {
|
||||
if filters.Matches(conn) {
|
||||
filtered = append(filtered, conn)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// processInfo holds information about a process
|
||||
type processInfo struct {
|
||||
pid int
|
||||
command string
|
||||
uid int
|
||||
user string
|
||||
}
|
||||
|
||||
// buildInodeToProcessMap scans /proc to map socket inodes to processes
|
||||
func buildInodeToProcessMap() (map[int64]*processInfo, error) {
|
||||
inodeMap := make(map[int64]*processInfo)
|
||||
|
||||
procDir, err := os.Open("/proc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer procDir.Close()
|
||||
|
||||
entries, err := procDir.Readdir(-1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// check if directory name is a number (pid)
|
||||
pidStr := entry.Name()
|
||||
pid, err := strconv.Atoi(pidStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// get process info
|
||||
procInfo, err := getProcessInfo(pid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// scan /proc/[pid]/fd/ for socket file descriptors
|
||||
fdDir := filepath.Join("/proc", pidStr, "fd")
|
||||
fdEntries, err := os.ReadDir(fdDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, fdEntry := range fdEntries {
|
||||
fdPath := filepath.Join(fdDir, fdEntry.Name())
|
||||
link, err := os.Readlink(fdPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// socket inodes look like: socket:[12345]
|
||||
if strings.HasPrefix(link, "socket:[") && strings.HasSuffix(link, "]") {
|
||||
inodeStr := link[8 : len(link)-1]
|
||||
inode, err := strconv.ParseInt(inodeStr, 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
inodeMap[inode] = procInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inodeMap, nil
|
||||
}
|
||||
|
||||
// getProcessInfo reads process information from /proc/[pid]/
|
||||
func getProcessInfo(pid int) (*processInfo, error) {
|
||||
info := &processInfo{pid: pid}
|
||||
|
||||
// prefer /proc/[pid]/comm as it's always just the command name
|
||||
commPath := filepath.Join("/proc", strconv.Itoa(pid), "comm")
|
||||
commData, err := os.ReadFile(commPath)
|
||||
if err == nil && len(commData) > 0 {
|
||||
info.command = strings.TrimSpace(string(commData))
|
||||
}
|
||||
|
||||
// if comm is not available, try cmdline
|
||||
if info.command == "" {
|
||||
cmdlinePath := filepath.Join("/proc", strconv.Itoa(pid), "cmdline")
|
||||
cmdlineData, err := os.ReadFile(cmdlinePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// cmdline is null-separated, take first part
|
||||
if len(cmdlineData) > 0 {
|
||||
parts := bytes.Split(cmdlineData, []byte{0})
|
||||
if len(parts) > 0 && len(parts[0]) > 0 {
|
||||
fullPath := string(parts[0])
|
||||
// extract basename from full path
|
||||
baseName := filepath.Base(fullPath)
|
||||
// if basename contains spaces (single-string cmdline), take first word
|
||||
if strings.Contains(baseName, " ") {
|
||||
baseName = strings.Fields(baseName)[0]
|
||||
}
|
||||
info.command = baseName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// read UID from /proc/[pid]/status
|
||||
statusPath := filepath.Join("/proc", strconv.Itoa(pid), "status")
|
||||
statusFile, err := os.Open(statusPath)
|
||||
if err != nil {
|
||||
return info, nil
|
||||
}
|
||||
defer statusFile.Close()
|
||||
|
||||
scanner := bufio.NewScanner(statusFile)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "Uid:") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
uid, err := strconv.Atoi(fields[1])
|
||||
if err == nil {
|
||||
info.uid = uid
|
||||
// get username from uid
|
||||
u, err := user.LookupId(strconv.Itoa(uid))
|
||||
if err == nil {
|
||||
info.user = u.Username
|
||||
} else {
|
||||
info.user = strconv.Itoa(uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// parseProcNet parses a /proc/net/tcp or /proc/net/udp file
|
||||
func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*processInfo) ([]Connection, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var connections []Connection
|
||||
scanner := bufio.NewScanner(file)
|
||||
|
||||
// skip header
|
||||
scanner.Scan()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 10 {
|
||||
continue
|
||||
}
|
||||
|
||||
// parse local address and port
|
||||
localAddr, localPort, err := parseHexAddr(fields[1])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// parse remote address and port
|
||||
remoteAddr, remotePort, err := parseHexAddr(fields[2])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// parse state (field 3)
|
||||
stateHex := fields[3]
|
||||
state := parseState(stateHex, proto)
|
||||
|
||||
// parse inode (field 9)
|
||||
inode, _ := strconv.ParseInt(fields[9], 10, 64)
|
||||
|
||||
conn := Connection{
|
||||
TS: time.Now(),
|
||||
Proto: proto,
|
||||
IPVersion: fmt.Sprintf("IPv%d", ipVersion),
|
||||
State: state,
|
||||
Laddr: localAddr,
|
||||
Lport: localPort,
|
||||
Raddr: remoteAddr,
|
||||
Rport: remotePort,
|
||||
Inode: inode,
|
||||
}
|
||||
|
||||
// add process info if available
|
||||
if procInfo, exists := inodeMap[inode]; exists {
|
||||
conn.PID = procInfo.pid
|
||||
conn.Process = procInfo.command
|
||||
conn.UID = procInfo.uid
|
||||
conn.User = procInfo.user
|
||||
}
|
||||
|
||||
// determine interface
|
||||
conn.Interface = guessNetworkInterface(localAddr, nil)
|
||||
|
||||
connections = append(connections, conn)
|
||||
}
|
||||
|
||||
return connections, scanner.Err()
|
||||
}
|
||||
|
||||
// parseState converts hex state value to string
|
||||
func parseState(hexState, proto string) string {
|
||||
state, err := strconv.ParseInt(hexState, 16, 32)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TCP states
|
||||
tcpStates := map[int64]string{
|
||||
0x01: "ESTABLISHED",
|
||||
0x02: "SYN_SENT",
|
||||
0x03: "SYN_RECV",
|
||||
0x04: "FIN_WAIT1",
|
||||
0x05: "FIN_WAIT2",
|
||||
0x06: "TIME_WAIT",
|
||||
0x07: "CLOSE",
|
||||
0x08: "CLOSE_WAIT",
|
||||
0x09: "LAST_ACK",
|
||||
0x0A: "LISTEN",
|
||||
0x0B: "CLOSING",
|
||||
}
|
||||
|
||||
if strings.HasPrefix(proto, "tcp") {
|
||||
if s, exists := tcpStates[state]; exists {
|
||||
return s
|
||||
}
|
||||
} else {
|
||||
// UDP doesn't have states in the same way
|
||||
if state == 0x07 {
|
||||
return "CLOSE"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseHexAddr parses hex-encoded address:port from /proc/net files
|
||||
func parseHexAddr(hexAddr string) (string, int, error) {
|
||||
parts := strings.Split(hexAddr, ":")
|
||||
if len(parts) != 2 {
|
||||
return "", 0, fmt.Errorf("invalid address format")
|
||||
}
|
||||
|
||||
hexIP := parts[0]
|
||||
|
||||
// parse hex port
|
||||
port, err := strconv.ParseInt(parts[1], 16, 32)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if len(hexIP) == 8 {
|
||||
// IPv4 (stored in little-endian)
|
||||
ip1, _ := strconv.ParseInt(hexIP[6:8], 16, 32)
|
||||
ip2, _ := strconv.ParseInt(hexIP[4:6], 16, 32)
|
||||
ip3, _ := strconv.ParseInt(hexIP[2:4], 16, 32)
|
||||
ip4, _ := strconv.ParseInt(hexIP[0:2], 16, 32)
|
||||
addr := fmt.Sprintf("%d.%d.%d.%d", ip1, ip2, ip3, ip4)
|
||||
|
||||
// handle wildcard address
|
||||
if addr == "0.0.0.0" {
|
||||
addr = "*"
|
||||
}
|
||||
|
||||
return addr, int(port), nil
|
||||
} else if len(hexIP) == 32 {
|
||||
// IPv6 (stored in little-endian per 32-bit word)
|
||||
var ipv6Parts []string
|
||||
for i := 0; i < 32; i += 8 {
|
||||
word := hexIP[i : i+8]
|
||||
// reverse byte order within each 32-bit word
|
||||
p1 := word[6:8] + word[4:6] + word[2:4] + word[0:2]
|
||||
ipv6Parts = append(ipv6Parts, p1)
|
||||
}
|
||||
|
||||
// convert to standard IPv6 notation
|
||||
fullAddr := strings.Join(ipv6Parts, "")
|
||||
var formatted []string
|
||||
for i := 0; i < len(fullAddr); i += 4 {
|
||||
formatted = append(formatted, fullAddr[i:i+4])
|
||||
}
|
||||
addr := strings.Join(formatted, ":")
|
||||
|
||||
// simplify IPv6 address
|
||||
addr = simplifyIPv6(addr)
|
||||
|
||||
// handle wildcard address
|
||||
if addr == "::" || addr == "0:0:0:0:0:0:0:0" {
|
||||
addr = "*"
|
||||
}
|
||||
|
||||
return addr, int(port), nil
|
||||
}
|
||||
|
||||
return "", 0, fmt.Errorf("unsupported address format")
|
||||
}
|
||||
|
||||
// simplifyIPv6 simplifies IPv6 address notation
|
||||
func simplifyIPv6(addr string) string {
|
||||
// remove leading zeros from each group
|
||||
parts := strings.Split(addr, ":")
|
||||
for i, part := range parts {
|
||||
// convert to int and back to remove leading zeros
|
||||
val, err := strconv.ParseInt(part, 16, 64)
|
||||
if err == nil {
|
||||
parts[i] = strconv.FormatInt(val, 16)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, ":")
|
||||
}
|
||||
|
||||
func guessNetworkInterface(addr string, interfaces map[string]string) string {
|
||||
// Simple heuristic - try to match common interface patterns
|
||||
if addr == "127.0.0.1" || addr == "::1" {
|
||||
return "lo"
|
||||
}
|
||||
|
||||
// Check if it's a private network address
|
||||
ip := net.ParseIP(addr)
|
||||
if ip != nil {
|
||||
if ip.IsLoopback() {
|
||||
return "lo"
|
||||
}
|
||||
// More sophisticated interface detection would require routing table analysis
|
||||
// For now, return a placeholder
|
||||
if ip.To4() != nil {
|
||||
return "eth0" // Common default for IPv4
|
||||
} else {
|
||||
return "eth0" // Common default for IPv6
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Add Unix socket support
|
||||
func GetUnixSockets() ([]Connection, error) {
|
||||
connections := []Connection{}
|
||||
|
||||
// Parse /proc/net/unix for Unix domain sockets
|
||||
file, err := os.Open("/proc/net/unix")
|
||||
if err != nil {
|
||||
return connections, nil // silently fail on non-Linux systems
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
// Skip header
|
||||
scanner.Scan()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 7 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse Unix socket information
|
||||
inode, _ := strconv.ParseInt(fields[6], 10, 64)
|
||||
path := ""
|
||||
if len(fields) > 7 {
|
||||
path = fields[7]
|
||||
}
|
||||
|
||||
conn := Connection{
|
||||
TS: time.Now(),
|
||||
Proto: "unix",
|
||||
Laddr: path,
|
||||
Raddr: "",
|
||||
State: "CONNECTED", // Simplified
|
||||
Inode: inode,
|
||||
Interface: "unix",
|
||||
}
|
||||
|
||||
connections = append(connections, conn)
|
||||
}
|
||||
|
||||
return connections, nil
|
||||
}
|
||||
|
||||
16
internal/collector/collector_test.go
Normal file
16
internal/collector/collector_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetConnections(t *testing.T) {
|
||||
// integration test to verify /proc parsing works
|
||||
conns, err := GetConnections()
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnections() returned an error: %v", err)
|
||||
}
|
||||
|
||||
// connections are dynamic, so just verify function succeeded
|
||||
t.Logf("Successfully got %d connections", len(conns))
|
||||
}
|
||||
129
internal/collector/filter.go
Normal file
129
internal/collector/filter.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FilterOptions struct {
|
||||
Proto string
|
||||
State string
|
||||
Pid int
|
||||
Proc string
|
||||
Lport int
|
||||
Rport int
|
||||
User string
|
||||
UID int
|
||||
Laddr string
|
||||
Raddr string
|
||||
Contains string
|
||||
IPv4 bool
|
||||
IPv6 bool
|
||||
Interface string
|
||||
Mark string
|
||||
Namespace string
|
||||
Inode int64
|
||||
Since time.Time
|
||||
SinceRel time.Duration
|
||||
}
|
||||
|
||||
func (f *FilterOptions) IsEmpty() bool {
|
||||
return f.Proto == "" && f.State == "" && f.Pid == 0 && f.Proc == "" &&
|
||||
f.Lport == 0 && f.Rport == 0 && f.User == "" && f.UID == 0 &&
|
||||
f.Laddr == "" && f.Raddr == "" && f.Contains == "" &&
|
||||
f.Interface == "" && f.Mark == "" && f.Namespace == "" && f.Inode == 0 &&
|
||||
f.Since.IsZero() && f.SinceRel == 0 && !f.IPv4 && !f.IPv6
|
||||
}
|
||||
|
||||
func (f *FilterOptions) Matches(c Connection) bool {
|
||||
if f.Proto != "" && !strings.EqualFold(c.Proto, f.Proto) {
|
||||
return false
|
||||
}
|
||||
if f.State != "" && !strings.EqualFold(c.State, f.State) {
|
||||
return false
|
||||
}
|
||||
if f.Pid != 0 && c.PID != f.Pid {
|
||||
return false
|
||||
}
|
||||
if f.Proc != "" && !containsIgnoreCase(c.Process, f.Proc) {
|
||||
return false
|
||||
}
|
||||
if f.Lport != 0 && c.Lport != f.Lport {
|
||||
return false
|
||||
}
|
||||
if f.Rport != 0 && c.Rport != f.Rport {
|
||||
return false
|
||||
}
|
||||
if f.User != "" && !strings.EqualFold(c.User, f.User) {
|
||||
return false
|
||||
}
|
||||
if f.UID != 0 && c.UID != f.UID {
|
||||
return false
|
||||
}
|
||||
if f.Laddr != "" && !strings.EqualFold(c.Laddr, f.Laddr) {
|
||||
return false
|
||||
}
|
||||
if f.Raddr != "" && !strings.EqualFold(c.Raddr, f.Raddr) {
|
||||
return false
|
||||
}
|
||||
if f.Contains != "" && !matchesContains(c, f.Contains) {
|
||||
return false
|
||||
}
|
||||
if f.IPv4 && c.IPVersion != "IPv4" {
|
||||
return false
|
||||
}
|
||||
if f.IPv6 && c.IPVersion != "IPv6" {
|
||||
return false
|
||||
}
|
||||
if f.Interface != "" && !strings.EqualFold(c.Interface, f.Interface) {
|
||||
return false
|
||||
}
|
||||
if f.Mark != "" && !strings.EqualFold(c.Mark, f.Mark) {
|
||||
return false
|
||||
}
|
||||
if f.Namespace != "" && !strings.EqualFold(c.Namespace, f.Namespace) {
|
||||
return false
|
||||
}
|
||||
if f.Inode != 0 && c.Inode != f.Inode {
|
||||
return false
|
||||
}
|
||||
if !f.Since.IsZero() && c.TS.Before(f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.SinceRel != 0 {
|
||||
threshold := time.Now().Add(-f.SinceRel)
|
||||
if c.TS.Before(threshold) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func containsIgnoreCase(s, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
|
||||
}
|
||||
|
||||
func matchesContains(c Connection, query string) bool {
|
||||
q := strings.ToLower(query)
|
||||
return containsIgnoreCase(c.Process, q) ||
|
||||
containsIgnoreCase(c.Laddr, q) ||
|
||||
containsIgnoreCase(c.Raddr, q) ||
|
||||
containsIgnoreCase(c.User, q)
|
||||
}
|
||||
|
||||
// ParseTimeFilter parses a time filter string (RFC3339 or relative like "5s", "2m", "1h")
|
||||
func ParseTimeFilter(timeStr string) (time.Time, time.Duration, error) {
|
||||
// Try parsing as RFC3339 first
|
||||
if t, err := time.Parse(time.RFC3339, timeStr); err == nil {
|
||||
return t, 0, nil
|
||||
}
|
||||
|
||||
// Try parsing as relative duration
|
||||
if dur, err := time.ParseDuration(timeStr); err == nil {
|
||||
return time.Time{}, dur, nil
|
||||
}
|
||||
|
||||
return time.Time{}, 0, nil // Invalid format, but don't error
|
||||
}
|
||||
|
||||
44
internal/collector/filter_test.go
Normal file
44
internal/collector/filter_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterConnections(t *testing.T) {
|
||||
conns := []Connection{
|
||||
{PID: 1, Process: "proc1", User: "user1", Proto: "tcp", State: "ESTABLISHED", Laddr: "1.1.1.1", Lport: 80, Raddr: "2.2.2.2", Rport: 1234},
|
||||
{PID: 2, Process: "proc2", User: "user2", Proto: "udp", State: "LISTEN", Laddr: "3.3.3.3", Lport: 53, Raddr: "*", Rport: 0},
|
||||
{PID: 3, Process: "proc1_extra", User: "user1", Proto: "tcp", State: "ESTABLISHED", Laddr: "4.4.4.4", Lport: 443, Raddr: "5.5.5.5", Rport: 5678},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
filters FilterOptions
|
||||
expected int
|
||||
}{
|
||||
{"No filters", FilterOptions{}, 3},
|
||||
{"Filter by proto tcp", FilterOptions{Proto: "tcp"}, 2},
|
||||
{"Filter by proto udp", FilterOptions{Proto: "udp"}, 1},
|
||||
{"Filter by state", FilterOptions{State: "ESTABLISHED"}, 2},
|
||||
{"Filter by pid", FilterOptions{Pid: 2}, 1},
|
||||
{"Filter by proc", FilterOptions{Proc: "proc1"}, 2},
|
||||
{"Filter by lport", FilterOptions{Lport: 80}, 1},
|
||||
{"Filter by rport", FilterOptions{Rport: 1234}, 1},
|
||||
{"Filter by user", FilterOptions{User: "user1"}, 2},
|
||||
{"Filter by laddr", FilterOptions{Laddr: "1.1.1.1"}, 1},
|
||||
{"Filter by raddr", FilterOptions{Raddr: "5.5.5.5"}, 1},
|
||||
{"Filter by contains proc", FilterOptions{Contains: "proc2"}, 1},
|
||||
{"Filter by contains addr", FilterOptions{Contains: "3.3.3.3"}, 1},
|
||||
{"Combined filter", FilterOptions{Proto: "tcp", State: "ESTABLISHED"}, 2},
|
||||
{"No match", FilterOptions{Proto: "tcp", State: "LISTEN"}, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
filtered := FilterConnections(conns, tc.filters)
|
||||
if len(filtered) != tc.expected {
|
||||
t.Errorf("Expected %d connections, but got %d", tc.expected, len(filtered))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
403
internal/collector/mock.go
Normal file
403
internal/collector/mock.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MockCollector provides deterministic test data for testing
|
||||
// It implements the Collector interface
|
||||
type MockCollector struct {
|
||||
connections []Connection
|
||||
}
|
||||
|
||||
// NewMockCollector creates a new mock collector with default test data
|
||||
func NewMockCollector() *MockCollector {
|
||||
return &MockCollector{
|
||||
connections: getDefaultTestConnections(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockCollectorFromFile creates a mock collector from a JSON fixture file
|
||||
func NewMockCollectorFromFile(filename string) (*MockCollector, error) {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var connections []Connection
|
||||
if err := json.Unmarshal(data, &connections); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MockCollector{
|
||||
connections: connections,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetConnections returns the mock connections
|
||||
func (m *MockCollector) GetConnections() ([]Connection, error) {
|
||||
// Return a copy to avoid mutation
|
||||
result := make([]Connection, len(m.connections))
|
||||
copy(result, m.connections)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AddConnection adds a connection to the mock data
|
||||
func (m *MockCollector) AddConnection(conn Connection) {
|
||||
m.connections = append(m.connections, conn)
|
||||
}
|
||||
|
||||
// SetConnections replaces all connections with the provided slice
|
||||
func (m *MockCollector) SetConnections(connections []Connection) {
|
||||
m.connections = make([]Connection, len(connections))
|
||||
copy(m.connections, connections)
|
||||
}
|
||||
|
||||
// SaveToFile saves the current connections to a JSON file
|
||||
func (m *MockCollector) SaveToFile(filename string) error {
|
||||
data, err := json.MarshalIndent(m.connections, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, data, 0644)
|
||||
}
|
||||
|
||||
// getDefaultTestConnections returns a set of deterministic test connections
|
||||
func getDefaultTestConnections() []Connection {
|
||||
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||
|
||||
return []Connection{
|
||||
{
|
||||
TS: baseTime,
|
||||
PID: 1234,
|
||||
Process: "nginx",
|
||||
User: "www-data",
|
||||
UID: 33,
|
||||
Proto: "tcp",
|
||||
IPVersion: "IPv4",
|
||||
State: "LISTEN",
|
||||
Laddr: "0.0.0.0",
|
||||
Lport: 80,
|
||||
Raddr: "*",
|
||||
Rport: 0,
|
||||
Interface: "eth0",
|
||||
RxBytes: 0,
|
||||
TxBytes: 0,
|
||||
RttMs: 0,
|
||||
Mark: "0x0",
|
||||
Namespace: "init",
|
||||
Inode: 12345,
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(time.Second),
|
||||
PID: 1234,
|
||||
Process: "nginx",
|
||||
User: "www-data",
|
||||
UID: 33,
|
||||
Proto: "tcp",
|
||||
IPVersion: "IPv4",
|
||||
State: "ESTABLISHED",
|
||||
Laddr: "10.0.0.1",
|
||||
Lport: 80,
|
||||
Raddr: "203.0.113.10",
|
||||
Rport: 52344,
|
||||
Interface: "eth0",
|
||||
RxBytes: 10240,
|
||||
TxBytes: 2048,
|
||||
RttMs: 1.7,
|
||||
Mark: "0x0",
|
||||
Namespace: "init",
|
||||
Inode: 12346,
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(2 * time.Second),
|
||||
PID: 5678,
|
||||
Process: "postgres",
|
||||
User: "postgres",
|
||||
UID: 26,
|
||||
Proto: "tcp",
|
||||
IPVersion: "IPv4",
|
||||
State: "LISTEN",
|
||||
Laddr: "127.0.0.1",
|
||||
Lport: 5432,
|
||||
Raddr: "*",
|
||||
Rport: 0,
|
||||
Interface: "lo",
|
||||
RxBytes: 0,
|
||||
TxBytes: 0,
|
||||
RttMs: 0,
|
||||
Mark: "0x0",
|
||||
Namespace: "init",
|
||||
Inode: 12347,
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(3 * time.Second),
|
||||
PID: 5678,
|
||||
Process: "postgres",
|
||||
User: "postgres",
|
||||
UID: 26,
|
||||
Proto: "tcp",
|
||||
IPVersion: "IPv4",
|
||||
State: "ESTABLISHED",
|
||||
Laddr: "127.0.0.1",
|
||||
Lport: 5432,
|
||||
Raddr: "127.0.0.1",
|
||||
Rport: 45678,
|
||||
Interface: "lo",
|
||||
RxBytes: 8192,
|
||||
TxBytes: 4096,
|
||||
RttMs: 0.1,
|
||||
Mark: "0x0",
|
||||
Namespace: "init",
|
||||
Inode: 12348,
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(4 * time.Second),
|
||||
PID: 9999,
|
||||
Process: "dns-server",
|
||||
User: "named",
|
||||
UID: 25,
|
||||
Proto: "udp",
|
||||
IPVersion: "IPv4",
|
||||
State: "CONNECTED",
|
||||
Laddr: "0.0.0.0",
|
||||
Lport: 53,
|
||||
Raddr: "*",
|
||||
Rport: 0,
|
||||
Interface: "eth0",
|
||||
RxBytes: 1024,
|
||||
TxBytes: 512,
|
||||
RttMs: 0,
|
||||
Mark: "0x0",
|
||||
Namespace: "init",
|
||||
Inode: 12349,
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(5 * time.Second),
|
||||
PID: 1111,
|
||||
Process: "ssh",
|
||||
User: "root",
|
||||
UID: 0,
|
||||
Proto: "tcp",
|
||||
IPVersion: "IPv4",
|
||||
State: "ESTABLISHED",
|
||||
Laddr: "192.168.1.100",
|
||||
Lport: 22,
|
||||
Raddr: "192.168.1.200",
|
||||
Rport: 54321,
|
||||
Interface: "eth0",
|
||||
RxBytes: 2048,
|
||||
TxBytes: 1024,
|
||||
RttMs: 2.3,
|
||||
Mark: "0x0",
|
||||
Namespace: "init",
|
||||
Inode: 12350,
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(6 * time.Second),
|
||||
PID: 2222,
|
||||
Process: "app-server",
|
||||
User: "app",
|
||||
UID: 1000,
|
||||
Proto: "unix",
|
||||
IPVersion: "",
|
||||
State: "CONNECTED",
|
||||
Laddr: "/tmp/app.sock",
|
||||
Lport: 0,
|
||||
Raddr: "",
|
||||
Rport: 0,
|
||||
Interface: "unix",
|
||||
RxBytes: 512,
|
||||
TxBytes: 256,
|
||||
RttMs: 0,
|
||||
Mark: "",
|
||||
Namespace: "init",
|
||||
Inode: 12351,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionBuilder provides a fluent interface for building test connections
|
||||
type ConnectionBuilder struct {
|
||||
conn Connection
|
||||
}
|
||||
|
||||
// NewConnectionBuilder creates a new connection builder with sensible defaults
|
||||
func NewConnectionBuilder() *ConnectionBuilder {
|
||||
return &ConnectionBuilder{
|
||||
conn: Connection{
|
||||
TS: time.Now(),
|
||||
PID: 1000,
|
||||
Process: "test-process",
|
||||
User: "test-user",
|
||||
UID: 1000,
|
||||
Proto: "tcp",
|
||||
IPVersion: "IPv4",
|
||||
State: "ESTABLISHED",
|
||||
Laddr: "127.0.0.1",
|
||||
Lport: 8080,
|
||||
Raddr: "127.0.0.1",
|
||||
Rport: 9090,
|
||||
Interface: "lo",
|
||||
RxBytes: 1024,
|
||||
TxBytes: 512,
|
||||
RttMs: 1.0,
|
||||
Mark: "0x0",
|
||||
Namespace: "init",
|
||||
Inode: 99999,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WithPID sets the PID
|
||||
func (b *ConnectionBuilder) WithPID(pid int) *ConnectionBuilder {
|
||||
b.conn.PID = pid
|
||||
return b
|
||||
}
|
||||
|
||||
// WithProcess sets the process name
|
||||
func (b *ConnectionBuilder) WithProcess(process string) *ConnectionBuilder {
|
||||
b.conn.Process = process
|
||||
return b
|
||||
}
|
||||
|
||||
// WithProto sets the protocol
|
||||
func (b *ConnectionBuilder) WithProto(proto string) *ConnectionBuilder {
|
||||
b.conn.Proto = proto
|
||||
return b
|
||||
}
|
||||
|
||||
// WithState sets the connection state
|
||||
func (b *ConnectionBuilder) WithState(state string) *ConnectionBuilder {
|
||||
b.conn.State = state
|
||||
return b
|
||||
}
|
||||
|
||||
// WithLocalAddr sets the local address and port
|
||||
func (b *ConnectionBuilder) WithLocalAddr(addr string, port int) *ConnectionBuilder {
|
||||
b.conn.Laddr = addr
|
||||
b.conn.Lport = port
|
||||
return b
|
||||
}
|
||||
|
||||
// WithRemoteAddr sets the remote address and port
|
||||
func (b *ConnectionBuilder) WithRemoteAddr(addr string, port int) *ConnectionBuilder {
|
||||
b.conn.Raddr = addr
|
||||
b.conn.Rport = port
|
||||
return b
|
||||
}
|
||||
|
||||
// WithInterface sets the network interface
|
||||
func (b *ConnectionBuilder) WithInterface(iface string) *ConnectionBuilder {
|
||||
b.conn.Interface = iface
|
||||
return b
|
||||
}
|
||||
|
||||
// WithBytes sets the rx and tx byte counts
|
||||
func (b *ConnectionBuilder) WithBytes(rx, tx int64) *ConnectionBuilder {
|
||||
b.conn.RxBytes = rx
|
||||
b.conn.TxBytes = tx
|
||||
return b
|
||||
}
|
||||
|
||||
// Build returns the constructed connection
|
||||
func (b *ConnectionBuilder) Build() Connection {
|
||||
return b.conn
|
||||
}
|
||||
|
||||
// TestFixture provides test scenarios for different use cases
|
||||
type TestFixture struct {
|
||||
Name string
|
||||
Description string
|
||||
Connections []Connection
|
||||
}
|
||||
|
||||
// GetTestFixtures returns predefined test fixtures for different scenarios
|
||||
func GetTestFixtures() []TestFixture {
|
||||
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||
|
||||
return []TestFixture{
|
||||
{
|
||||
Name: "empty",
|
||||
Description: "No connections",
|
||||
Connections: []Connection{},
|
||||
},
|
||||
{
|
||||
Name: "single-tcp",
|
||||
Description: "Single TCP connection",
|
||||
Connections: []Connection{
|
||||
NewConnectionBuilder().
|
||||
WithPID(1234).
|
||||
WithProcess("test-app").
|
||||
WithProto("tcp").
|
||||
WithState("ESTABLISHED").
|
||||
WithLocalAddr("127.0.0.1", 8080).
|
||||
WithRemoteAddr("127.0.0.1", 9090).
|
||||
Build(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed-protocols",
|
||||
Description: "Mix of TCP, UDP, and Unix sockets",
|
||||
Connections: []Connection{
|
||||
{
|
||||
TS: baseTime,
|
||||
PID: 1,
|
||||
Process: "tcp-server",
|
||||
Proto: "tcp",
|
||||
State: "LISTEN",
|
||||
Laddr: "0.0.0.0",
|
||||
Lport: 80,
|
||||
Interface: "eth0",
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(time.Second),
|
||||
PID: 2,
|
||||
Process: "udp-server",
|
||||
Proto: "udp",
|
||||
State: "CONNECTED",
|
||||
Laddr: "0.0.0.0",
|
||||
Lport: 53,
|
||||
Interface: "eth0",
|
||||
},
|
||||
{
|
||||
TS: baseTime.Add(2 * time.Second),
|
||||
PID: 3,
|
||||
Process: "unix-app",
|
||||
Proto: "unix",
|
||||
State: "CONNECTED",
|
||||
Laddr: "/tmp/test.sock",
|
||||
Interface: "unix",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "high-volume",
|
||||
Description: "Large number of connections for performance testing",
|
||||
Connections: generateHighVolumeConnections(1000),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateHighVolumeConnections creates a large number of test connections
|
||||
func generateHighVolumeConnections(count int) []Connection {
|
||||
connections := make([]Connection, count)
|
||||
baseTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
connections[i] = NewConnectionBuilder().
|
||||
WithPID(1000 + i).
|
||||
WithProcess("worker-" + string(rune('a'+i%26))).
|
||||
WithProto([]string{"tcp", "udp"}[i%2]).
|
||||
WithState([]string{"ESTABLISHED", "LISTEN", "TIME_WAIT"}[i%3]).
|
||||
WithLocalAddr("127.0.0.1", 8000+i%1000).
|
||||
WithRemoteAddr("10.0.0."+string(rune('1'+i%10)), 9000+i%1000).
|
||||
Build()
|
||||
connections[i].TS = baseTime.Add(time.Duration(i) * time.Millisecond)
|
||||
}
|
||||
|
||||
return connections
|
||||
}
|
||||
159
internal/collector/query.go
Normal file
159
internal/collector/query.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package collector
|
||||
|
||||
// Query combines filtering, sorting, and limiting into a single operation
|
||||
type Query struct {
|
||||
Filter FilterOptions
|
||||
Sort SortOptions
|
||||
Limit int
|
||||
}
|
||||
|
||||
// NewQuery creates a query with sensible defaults
|
||||
func NewQuery() *Query {
|
||||
return &Query{
|
||||
Filter: FilterOptions{},
|
||||
Sort: SortOptions{Field: SortByLport, Direction: SortAsc},
|
||||
Limit: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// WithFilter sets the filter options
|
||||
func (q *Query) WithFilter(f FilterOptions) *Query {
|
||||
q.Filter = f
|
||||
return q
|
||||
}
|
||||
|
||||
// WithSort sets the sort options
|
||||
func (q *Query) WithSort(s SortOptions) *Query {
|
||||
q.Sort = s
|
||||
return q
|
||||
}
|
||||
|
||||
// WithSortString parses and sets sort options from a string like "pid:desc"
|
||||
func (q *Query) WithSortString(s string) *Query {
|
||||
q.Sort = ParseSortOptions(s)
|
||||
return q
|
||||
}
|
||||
|
||||
// WithLimit sets the maximum number of results
|
||||
func (q *Query) WithLimit(n int) *Query {
|
||||
q.Limit = n
|
||||
return q
|
||||
}
|
||||
|
||||
// Proto filters by protocol
|
||||
func (q *Query) Proto(proto string) *Query {
|
||||
q.Filter.Proto = proto
|
||||
return q
|
||||
}
|
||||
|
||||
// State filters by connection state
|
||||
func (q *Query) State(state string) *Query {
|
||||
q.Filter.State = state
|
||||
return q
|
||||
}
|
||||
|
||||
// Process filters by process name (substring match)
|
||||
func (q *Query) Process(proc string) *Query {
|
||||
q.Filter.Proc = proc
|
||||
return q
|
||||
}
|
||||
|
||||
// PID filters by process ID
|
||||
func (q *Query) PID(pid int) *Query {
|
||||
q.Filter.Pid = pid
|
||||
return q
|
||||
}
|
||||
|
||||
// LocalPort filters by local port
|
||||
func (q *Query) LocalPort(port int) *Query {
|
||||
q.Filter.Lport = port
|
||||
return q
|
||||
}
|
||||
|
||||
// RemotePort filters by remote port
|
||||
func (q *Query) RemotePort(port int) *Query {
|
||||
q.Filter.Rport = port
|
||||
return q
|
||||
}
|
||||
|
||||
// IPv4Only filters to only IPv4 connections
|
||||
func (q *Query) IPv4Only() *Query {
|
||||
q.Filter.IPv4 = true
|
||||
q.Filter.IPv6 = false
|
||||
return q
|
||||
}
|
||||
|
||||
// IPv6Only filters to only IPv6 connections
|
||||
func (q *Query) IPv6Only() *Query {
|
||||
q.Filter.IPv4 = false
|
||||
q.Filter.IPv6 = true
|
||||
return q
|
||||
}
|
||||
|
||||
// Listening filters to only listening sockets
|
||||
func (q *Query) Listening() *Query {
|
||||
q.Filter.State = "LISTEN"
|
||||
return q
|
||||
}
|
||||
|
||||
// Established filters to only established connections
|
||||
func (q *Query) Established() *Query {
|
||||
q.Filter.State = "ESTABLISHED"
|
||||
return q
|
||||
}
|
||||
|
||||
// Contains filters by substring in process, local addr, or remote addr
|
||||
func (q *Query) Contains(s string) *Query {
|
||||
q.Filter.Contains = s
|
||||
return q
|
||||
}
|
||||
|
||||
// Execute runs the query and returns results
|
||||
func (q *Query) Execute() ([]Connection, error) {
|
||||
conns, err := GetConnections()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return q.Apply(conns), nil
|
||||
}
|
||||
|
||||
// Apply applies the query to a slice of connections
|
||||
func (q *Query) Apply(conns []Connection) []Connection {
|
||||
result := FilterConnections(conns, q.Filter)
|
||||
SortConnections(result, q.Sort)
|
||||
|
||||
if q.Limit > 0 && len(result) > q.Limit {
|
||||
result = result[:q.Limit]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// common pre-built queries
|
||||
|
||||
// ListeningTCP returns a query for TCP listeners
|
||||
func ListeningTCP() *Query {
|
||||
return NewQuery().Proto("tcp").Listening()
|
||||
}
|
||||
|
||||
// ListeningAll returns a query for all listeners
|
||||
func ListeningAll() *Query {
|
||||
return NewQuery().Listening()
|
||||
}
|
||||
|
||||
// EstablishedTCP returns a query for established TCP connections
|
||||
func EstablishedTCP() *Query {
|
||||
return NewQuery().Proto("tcp").Established()
|
||||
}
|
||||
|
||||
// ByProcess returns a query filtered by process name
|
||||
func ByProcess(name string) *Query {
|
||||
return NewQuery().Process(name)
|
||||
}
|
||||
|
||||
// ByPort returns a query filtered by local port
|
||||
func ByPort(port int) *Query {
|
||||
return NewQuery().LocalPort(port)
|
||||
}
|
||||
|
||||
165
internal/collector/query_test.go
Normal file
165
internal/collector/query_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQueryBuilder(t *testing.T) {
|
||||
t.Run("fluent builder pattern", func(t *testing.T) {
|
||||
q := NewQuery().
|
||||
Proto("tcp").
|
||||
State("LISTEN").
|
||||
WithLimit(10)
|
||||
|
||||
if q.Filter.Proto != "tcp" {
|
||||
t.Errorf("expected proto tcp, got %s", q.Filter.Proto)
|
||||
}
|
||||
if q.Filter.State != "LISTEN" {
|
||||
t.Errorf("expected state LISTEN, got %s", q.Filter.State)
|
||||
}
|
||||
if q.Limit != 10 {
|
||||
t.Errorf("expected limit 10, got %d", q.Limit)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("convenience methods", func(t *testing.T) {
|
||||
q := NewQuery().Listening()
|
||||
if q.Filter.State != "LISTEN" {
|
||||
t.Errorf("Listening() should set state to LISTEN")
|
||||
}
|
||||
|
||||
q = NewQuery().Established()
|
||||
if q.Filter.State != "ESTABLISHED" {
|
||||
t.Errorf("Established() should set state to ESTABLISHED")
|
||||
}
|
||||
|
||||
q = NewQuery().IPv4Only()
|
||||
if q.Filter.IPv4 != true || q.Filter.IPv6 != false {
|
||||
t.Error("IPv4Only() should set IPv4=true, IPv6=false")
|
||||
}
|
||||
|
||||
q = NewQuery().IPv6Only()
|
||||
if q.Filter.IPv4 != false || q.Filter.IPv6 != true {
|
||||
t.Error("IPv6Only() should set IPv4=false, IPv6=true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort options", func(t *testing.T) {
|
||||
q := NewQuery().WithSortString("pid:desc")
|
||||
|
||||
if q.Sort.Field != SortByPID {
|
||||
t.Errorf("expected sort by pid, got %v", q.Sort.Field)
|
||||
}
|
||||
if q.Sort.Direction != SortDesc {
|
||||
t.Errorf("expected sort desc, got %v", q.Sort.Direction)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryApply(t *testing.T) {
|
||||
conns := []Connection{
|
||||
{PID: 1, Process: "nginx", Proto: "tcp", State: "LISTEN", Lport: 80},
|
||||
{PID: 2, Process: "nginx", Proto: "tcp", State: "ESTABLISHED", Lport: 80},
|
||||
{PID: 3, Process: "sshd", Proto: "tcp", State: "LISTEN", Lport: 22},
|
||||
{PID: 4, Process: "postgres", Proto: "tcp", State: "LISTEN", Lport: 5432},
|
||||
{PID: 5, Process: "dnsmasq", Proto: "udp", State: "", Lport: 53},
|
||||
}
|
||||
|
||||
t.Run("filter by state", func(t *testing.T) {
|
||||
q := NewQuery().Listening()
|
||||
result := q.Apply(conns)
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Errorf("expected 3 listening connections, got %d", len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by proto", func(t *testing.T) {
|
||||
q := NewQuery().Proto("udp")
|
||||
result := q.Apply(conns)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 udp connection, got %d", len(result))
|
||||
}
|
||||
if result[0].Process != "dnsmasq" {
|
||||
t.Errorf("expected dnsmasq, got %s", result[0].Process)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter and sort", func(t *testing.T) {
|
||||
q := NewQuery().Listening().WithSortString("lport:asc")
|
||||
result := q.Apply(conns)
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(result))
|
||||
}
|
||||
if result[0].Lport != 22 {
|
||||
t.Errorf("expected port 22 first, got %d", result[0].Lport)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter sort and limit", func(t *testing.T) {
|
||||
q := NewQuery().Proto("tcp").WithSortString("lport:asc").WithLimit(2)
|
||||
result := q.Apply(conns)
|
||||
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected 2 (limit), got %d", len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("process filter substring", func(t *testing.T) {
|
||||
q := NewQuery().Process("nginx")
|
||||
result := q.Apply(conns)
|
||||
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected 2 nginx connections, got %d", len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("contains filter", func(t *testing.T) {
|
||||
q := NewQuery().Contains("post")
|
||||
result := q.Apply(conns)
|
||||
|
||||
if len(result) != 1 || result[0].Process != "postgres" {
|
||||
t.Errorf("expected postgres, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrebuiltQueries(t *testing.T) {
|
||||
t.Run("ListeningTCP", func(t *testing.T) {
|
||||
q := ListeningTCP()
|
||||
if q.Filter.Proto != "tcp" || q.Filter.State != "LISTEN" {
|
||||
t.Error("ListeningTCP should filter tcp + LISTEN")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ListeningAll", func(t *testing.T) {
|
||||
q := ListeningAll()
|
||||
if q.Filter.State != "LISTEN" {
|
||||
t.Error("ListeningAll should filter LISTEN state")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EstablishedTCP", func(t *testing.T) {
|
||||
q := EstablishedTCP()
|
||||
if q.Filter.Proto != "tcp" || q.Filter.State != "ESTABLISHED" {
|
||||
t.Error("EstablishedTCP should filter tcp + ESTABLISHED")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ByProcess", func(t *testing.T) {
|
||||
q := ByProcess("nginx")
|
||||
if q.Filter.Proc != "nginx" {
|
||||
t.Error("ByProcess should set process filter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ByPort", func(t *testing.T) {
|
||||
q := ByPort(8080)
|
||||
if q.Filter.Lport != 8080 {
|
||||
t.Error("ByPort should set lport filter")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
131
internal/collector/sort.go
Normal file
131
internal/collector/sort.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SortField represents a field to sort by
|
||||
type SortField string
|
||||
|
||||
const (
|
||||
SortByPID SortField = "pid"
|
||||
SortByProcess SortField = "process"
|
||||
SortByUser SortField = "user"
|
||||
SortByProto SortField = "proto"
|
||||
SortByState SortField = "state"
|
||||
SortByLaddr SortField = "laddr"
|
||||
SortByLport SortField = "lport"
|
||||
SortByRaddr SortField = "raddr"
|
||||
SortByRport SortField = "rport"
|
||||
SortByInterface SortField = "if"
|
||||
SortByRxBytes SortField = "rx_bytes"
|
||||
SortByTxBytes SortField = "tx_bytes"
|
||||
SortByRttMs SortField = "rtt_ms"
|
||||
SortByTimestamp SortField = "ts"
|
||||
)
|
||||
|
||||
// SortDirection represents ascending or descending order
|
||||
type SortDirection int
|
||||
|
||||
const (
|
||||
SortAsc SortDirection = iota
|
||||
SortDesc
|
||||
)
|
||||
|
||||
// SortOptions configures how connections are sorted
|
||||
type SortOptions struct {
|
||||
Field SortField
|
||||
Direction SortDirection
|
||||
}
|
||||
|
||||
// ParseSortOptions parses a sort string like "pid:desc" or "lport"
|
||||
func ParseSortOptions(s string) SortOptions {
|
||||
if s == "" {
|
||||
return SortOptions{Field: SortByLport, Direction: SortAsc}
|
||||
}
|
||||
|
||||
parts := strings.SplitN(s, ":", 2)
|
||||
field := SortField(strings.ToLower(parts[0]))
|
||||
direction := SortAsc
|
||||
|
||||
if len(parts) > 1 && strings.ToLower(parts[1]) == "desc" {
|
||||
direction = SortDesc
|
||||
}
|
||||
|
||||
return SortOptions{Field: field, Direction: direction}
|
||||
}
|
||||
|
||||
// SortConnections sorts a slice of connections in place
|
||||
func SortConnections(conns []Connection, opts SortOptions) {
|
||||
if len(conns) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
sort.SliceStable(conns, func(i, j int) bool {
|
||||
less := compareConnections(conns[i], conns[j], opts.Field)
|
||||
if opts.Direction == SortDesc {
|
||||
return !less
|
||||
}
|
||||
return less
|
||||
})
|
||||
}
|
||||
|
||||
func compareConnections(a, b Connection, field SortField) bool {
|
||||
switch field {
|
||||
case SortByPID:
|
||||
return a.PID < b.PID
|
||||
case SortByProcess:
|
||||
return strings.ToLower(a.Process) < strings.ToLower(b.Process)
|
||||
case SortByUser:
|
||||
return strings.ToLower(a.User) < strings.ToLower(b.User)
|
||||
case SortByProto:
|
||||
return a.Proto < b.Proto
|
||||
case SortByState:
|
||||
return stateOrder(a.State) < stateOrder(b.State)
|
||||
case SortByLaddr:
|
||||
return a.Laddr < b.Laddr
|
||||
case SortByLport:
|
||||
return a.Lport < b.Lport
|
||||
case SortByRaddr:
|
||||
return a.Raddr < b.Raddr
|
||||
case SortByRport:
|
||||
return a.Rport < b.Rport
|
||||
case SortByInterface:
|
||||
return a.Interface < b.Interface
|
||||
case SortByRxBytes:
|
||||
return a.RxBytes < b.RxBytes
|
||||
case SortByTxBytes:
|
||||
return a.TxBytes < b.TxBytes
|
||||
case SortByRttMs:
|
||||
return a.RttMs < b.RttMs
|
||||
case SortByTimestamp:
|
||||
return a.TS.Before(b.TS)
|
||||
default:
|
||||
return a.Lport < b.Lport
|
||||
}
|
||||
}
|
||||
|
||||
// stateOrder returns a numeric order for connection states
|
||||
// puts LISTEN first, then ESTABLISHED, then others
|
||||
func stateOrder(state string) int {
|
||||
order := map[string]int{
|
||||
"LISTEN": 0,
|
||||
"ESTABLISHED": 1,
|
||||
"SYN_SENT": 2,
|
||||
"SYN_RECV": 3,
|
||||
"FIN_WAIT1": 4,
|
||||
"FIN_WAIT2": 5,
|
||||
"TIME_WAIT": 6,
|
||||
"CLOSE_WAIT": 7,
|
||||
"LAST_ACK": 8,
|
||||
"CLOSING": 9,
|
||||
"CLOSED": 10,
|
||||
}
|
||||
|
||||
if o, exists := order[strings.ToUpper(state)]; exists {
|
||||
return o
|
||||
}
|
||||
return 99
|
||||
}
|
||||
|
||||
130
internal/collector/sort_test.go
Normal file
130
internal/collector/sort_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSortConnections(t *testing.T) {
|
||||
conns := []Connection{
|
||||
{PID: 3, Process: "nginx", Lport: 80, State: "ESTABLISHED"},
|
||||
{PID: 1, Process: "sshd", Lport: 22, State: "LISTEN"},
|
||||
{PID: 2, Process: "postgres", Lport: 5432, State: "LISTEN"},
|
||||
}
|
||||
|
||||
t.Run("sort by PID ascending", func(t *testing.T) {
|
||||
c := make([]Connection, len(conns))
|
||||
copy(c, conns)
|
||||
|
||||
SortConnections(c, SortOptions{Field: SortByPID, Direction: SortAsc})
|
||||
|
||||
if c[0].PID != 1 || c[1].PID != 2 || c[2].PID != 3 {
|
||||
t.Errorf("expected PIDs [1,2,3], got [%d,%d,%d]", c[0].PID, c[1].PID, c[2].PID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by PID descending", func(t *testing.T) {
|
||||
c := make([]Connection, len(conns))
|
||||
copy(c, conns)
|
||||
|
||||
SortConnections(c, SortOptions{Field: SortByPID, Direction: SortDesc})
|
||||
|
||||
if c[0].PID != 3 || c[1].PID != 2 || c[2].PID != 1 {
|
||||
t.Errorf("expected PIDs [3,2,1], got [%d,%d,%d]", c[0].PID, c[1].PID, c[2].PID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by port ascending", func(t *testing.T) {
|
||||
c := make([]Connection, len(conns))
|
||||
copy(c, conns)
|
||||
|
||||
SortConnections(c, SortOptions{Field: SortByLport, Direction: SortAsc})
|
||||
|
||||
if c[0].Lport != 22 || c[1].Lport != 80 || c[2].Lport != 5432 {
|
||||
t.Errorf("expected ports [22,80,5432], got [%d,%d,%d]", c[0].Lport, c[1].Lport, c[2].Lport)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by state puts LISTEN first", func(t *testing.T) {
|
||||
c := make([]Connection, len(conns))
|
||||
copy(c, conns)
|
||||
|
||||
SortConnections(c, SortOptions{Field: SortByState, Direction: SortAsc})
|
||||
|
||||
if c[0].State != "LISTEN" || c[1].State != "LISTEN" || c[2].State != "ESTABLISHED" {
|
||||
t.Errorf("expected LISTEN states first, got [%s,%s,%s]", c[0].State, c[1].State, c[2].State)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by process case insensitive", func(t *testing.T) {
|
||||
c := []Connection{
|
||||
{Process: "Nginx"},
|
||||
{Process: "apache"},
|
||||
{Process: "SSHD"},
|
||||
}
|
||||
|
||||
SortConnections(c, SortOptions{Field: SortByProcess, Direction: SortAsc})
|
||||
|
||||
if c[0].Process != "apache" {
|
||||
t.Errorf("expected apache first, got %s", c[0].Process)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseSortOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantField SortField
|
||||
wantDir SortDirection
|
||||
}{
|
||||
{"pid", SortByPID, SortAsc},
|
||||
{"pid:asc", SortByPID, SortAsc},
|
||||
{"pid:desc", SortByPID, SortDesc},
|
||||
{"lport", SortByLport, SortAsc},
|
||||
{"LPORT:DESC", SortByLport, SortDesc},
|
||||
{"", SortByLport, SortAsc}, // default
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
opts := ParseSortOptions(tt.input)
|
||||
if opts.Field != tt.wantField {
|
||||
t.Errorf("field: got %v, want %v", opts.Field, tt.wantField)
|
||||
}
|
||||
if opts.Direction != tt.wantDir {
|
||||
t.Errorf("direction: got %v, want %v", opts.Direction, tt.wantDir)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateOrder(t *testing.T) {
|
||||
if stateOrder("LISTEN") >= stateOrder("ESTABLISHED") {
|
||||
t.Error("LISTEN should come before ESTABLISHED")
|
||||
}
|
||||
if stateOrder("ESTABLISHED") >= stateOrder("TIME_WAIT") {
|
||||
t.Error("ESTABLISHED should come before TIME_WAIT")
|
||||
}
|
||||
if stateOrder("UNKNOWN") != 99 {
|
||||
t.Error("unknown states should return 99")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortByTimestamp(t *testing.T) {
|
||||
now := time.Now()
|
||||
conns := []Connection{
|
||||
{TS: now.Add(2 * time.Second)},
|
||||
{TS: now},
|
||||
{TS: now.Add(1 * time.Second)},
|
||||
}
|
||||
|
||||
SortConnections(conns, SortOptions{Field: SortByTimestamp, Direction: SortAsc})
|
||||
|
||||
if !conns[0].TS.Equal(now) {
|
||||
t.Error("oldest timestamp should be first")
|
||||
}
|
||||
if !conns[2].TS.Equal(now.Add(2 * time.Second)) {
|
||||
t.Error("newest timestamp should be last")
|
||||
}
|
||||
}
|
||||
|
||||
25
internal/collector/types.go
Normal file
25
internal/collector/types.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package collector
|
||||
|
||||
import "time"
|
||||
|
||||
type Connection struct {
|
||||
TS time.Time `json:"ts"`
|
||||
PID int `json:"pid"`
|
||||
Process string `json:"process"`
|
||||
User string `json:"user"`
|
||||
UID int `json:"uid"`
|
||||
Proto string `json:"proto"`
|
||||
IPVersion string `json:"ipversion"`
|
||||
State string `json:"state"`
|
||||
Laddr string `json:"laddr"`
|
||||
Lport int `json:"lport"`
|
||||
Raddr string `json:"raddr"`
|
||||
Rport int `json:"rport"`
|
||||
Interface string `json:"interface"`
|
||||
RxBytes int64 `json:"rx_bytes"`
|
||||
TxBytes int64 `json:"tx_bytes"`
|
||||
RttMs float64 `json:"rtt_ms"`
|
||||
Mark string `json:"mark"`
|
||||
Namespace string `json:"namespace"`
|
||||
Inode int64 `json:"inode"`
|
||||
}
|
||||
60
internal/color/color.go
Normal file
60
internal/color/color.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package color
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
var (
|
||||
Header = color.New(color.FgGreen, color.Bold)
|
||||
Bold = color.New(color.Bold)
|
||||
Faint = color.New(color.Faint)
|
||||
TCP = color.New(color.FgCyan)
|
||||
UDP = color.New(color.FgMagenta)
|
||||
LISTEN = color.New(color.FgYellow)
|
||||
ESTABLISHED = color.New(color.FgGreen)
|
||||
Default = color.New(color.FgWhite)
|
||||
)
|
||||
|
||||
func Init(mode string) {
|
||||
switch mode {
|
||||
case "always":
|
||||
color.NoColor = false
|
||||
case "never":
|
||||
color.NoColor = true
|
||||
case "auto":
|
||||
if os.Getenv("NO_COLOR") != "" || os.Getenv("TERM") == "dumb" {
|
||||
color.NoColor = true
|
||||
} else {
|
||||
color.NoColor = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func IsColorDisabled() bool {
|
||||
return color.NoColor
|
||||
}
|
||||
|
||||
func GetProtoColor(proto string) *color.Color {
|
||||
switch strings.ToLower(proto) {
|
||||
case "tcp":
|
||||
return TCP
|
||||
case "udp":
|
||||
return UDP
|
||||
default:
|
||||
return Default
|
||||
}
|
||||
}
|
||||
|
||||
func GetStateColor(state string) *color.Color {
|
||||
switch strings.ToUpper(state) {
|
||||
case "LISTEN", "LISTENING":
|
||||
return LISTEN
|
||||
case "ESTABLISHED":
|
||||
return ESTABLISHED
|
||||
default:
|
||||
return Default
|
||||
}
|
||||
}
|
||||
46
internal/color/color_test.go
Normal file
46
internal/color/color_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package color
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mode string
|
||||
noColor string
|
||||
term string
|
||||
expected bool
|
||||
}{
|
||||
{"Always", "always", "", "", false},
|
||||
{"Never", "never", "", "", true},
|
||||
{"Auto no env", "auto", "", "xterm-256color", false},
|
||||
{"Auto with NO_COLOR", "auto", "1", "xterm-256color", true},
|
||||
{"Auto with TERM=dumb", "auto", "", "dumb", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Save original env vars
|
||||
origNoColor := os.Getenv("NO_COLOR")
|
||||
origTerm := os.Getenv("TERM")
|
||||
|
||||
// Set test env vars
|
||||
os.Setenv("NO_COLOR", tc.noColor)
|
||||
os.Setenv("TERM", tc.term)
|
||||
|
||||
Init(tc.mode)
|
||||
|
||||
if color.NoColor != tc.expected {
|
||||
t.Errorf("Expected color.NoColor to be %v, but got %v", tc.expected, color.NoColor)
|
||||
}
|
||||
|
||||
// Restore original env vars
|
||||
os.Setenv("NO_COLOR", origNoColor)
|
||||
os.Setenv("TERM", origTerm)
|
||||
})
|
||||
}
|
||||
}
|
||||
203
internal/config/config.go
Normal file
203
internal/config/config.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Config represents the application configuration
|
||||
type Config struct {
|
||||
Defaults DefaultConfig `mapstructure:"defaults"`
|
||||
}
|
||||
|
||||
// DefaultConfig contains default values for CLI options
|
||||
type DefaultConfig struct {
|
||||
Interval string `mapstructure:"interval"`
|
||||
Numeric bool `mapstructure:"numeric"`
|
||||
Fields []string `mapstructure:"fields"`
|
||||
Theme string `mapstructure:"theme"`
|
||||
Units string `mapstructure:"units"`
|
||||
Color string `mapstructure:"color"`
|
||||
Resolve bool `mapstructure:"resolve"`
|
||||
IPv4 bool `mapstructure:"ipv4"`
|
||||
IPv6 bool `mapstructure:"ipv6"`
|
||||
NoHeaders bool `mapstructure:"no_headers"`
|
||||
OutputFormat string `mapstructure:"output_format"`
|
||||
SortBy string `mapstructure:"sort_by"`
|
||||
}
|
||||
|
||||
var globalConfig *Config
|
||||
|
||||
// Load loads configuration from file and environment variables
|
||||
func Load() (*Config, error) {
|
||||
if globalConfig != nil {
|
||||
return globalConfig, nil
|
||||
}
|
||||
|
||||
v := viper.New()
|
||||
|
||||
// set config name and file type (auto-detect based on extension)
|
||||
v.SetConfigName("snitch")
|
||||
// don't set config type - let viper auto-detect based on file extension
|
||||
// this allows both .toml and .yaml files to work
|
||||
v.AddConfigPath("$HOME/.config/snitch")
|
||||
v.AddConfigPath("$HOME/.snitch")
|
||||
v.AddConfigPath("/etc/snitch")
|
||||
|
||||
// Environment variables
|
||||
v.SetEnvPrefix("SNITCH")
|
||||
v.AutomaticEnv()
|
||||
|
||||
// environment variable bindings for readme-documented variables
|
||||
_ = v.BindEnv("config", "SNITCH_CONFIG")
|
||||
_ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE")
|
||||
_ = v.BindEnv("defaults.theme", "SNITCH_THEME")
|
||||
_ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR")
|
||||
|
||||
// Set defaults
|
||||
setDefaults(v)
|
||||
|
||||
// Handle SNITCH_CONFIG environment variable for custom config path
|
||||
if configPath := os.Getenv("SNITCH_CONFIG"); configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
|
||||
// Try to read config file
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
// It's OK if config file doesn't exist
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("error reading config file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle special environment variables
|
||||
handleSpecialEnvVars(v)
|
||||
|
||||
// Unmarshal into config struct
|
||||
config := &Config{}
|
||||
if err := v.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling config: %w", err)
|
||||
}
|
||||
|
||||
globalConfig = config
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func setDefaults(v *viper.Viper) {
|
||||
// Set default values matching the README specification
|
||||
v.SetDefault("defaults.interval", "1s")
|
||||
v.SetDefault("defaults.numeric", false)
|
||||
v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"})
|
||||
v.SetDefault("defaults.theme", "auto")
|
||||
v.SetDefault("defaults.units", "auto")
|
||||
v.SetDefault("defaults.color", "auto")
|
||||
v.SetDefault("defaults.resolve", true)
|
||||
v.SetDefault("defaults.ipv4", false)
|
||||
v.SetDefault("defaults.ipv6", false)
|
||||
v.SetDefault("defaults.no_headers", false)
|
||||
v.SetDefault("defaults.output_format", "table")
|
||||
v.SetDefault("defaults.sort_by", "")
|
||||
}
|
||||
|
||||
func handleSpecialEnvVars(v *viper.Viper) {
|
||||
// Handle SNITCH_NO_COLOR - if set to "1", disable color
|
||||
if os.Getenv("SNITCH_NO_COLOR") == "1" {
|
||||
v.Set("defaults.color", "never")
|
||||
}
|
||||
|
||||
// Handle SNITCH_RESOLVE - if set to "0", disable resolution
|
||||
if os.Getenv("SNITCH_RESOLVE") == "0" {
|
||||
v.Set("defaults.resolve", false)
|
||||
v.Set("defaults.numeric", true)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the global configuration, loading it if necessary
|
||||
func Get() *Config {
|
||||
if globalConfig == nil {
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
// Return default config on error
|
||||
return &Config{
|
||||
Defaults: DefaultConfig{
|
||||
Interval: "1s",
|
||||
Numeric: false,
|
||||
Fields: []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"},
|
||||
Theme: "auto",
|
||||
Units: "auto",
|
||||
Color: "auto",
|
||||
Resolve: true,
|
||||
IPv4: false,
|
||||
IPv6: false,
|
||||
NoHeaders: false,
|
||||
OutputFormat: "table",
|
||||
SortBy: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
return config
|
||||
}
|
||||
return globalConfig
|
||||
}
|
||||
|
||||
// GetInterval returns the configured interval as a duration
|
||||
func (c *Config) GetInterval() time.Duration {
|
||||
if duration, err := time.ParseDuration(c.Defaults.Interval); err == nil {
|
||||
return duration
|
||||
}
|
||||
return time.Second // default fallback
|
||||
}
|
||||
|
||||
// CreateExampleConfig creates an example configuration file
|
||||
func CreateExampleConfig(path string) error {
|
||||
exampleConfig := `# snitch configuration file
|
||||
# See https://github.com/you/snitch for full documentation
|
||||
|
||||
[defaults]
|
||||
# Default refresh interval for watch/stats/trace commands
|
||||
interval = "1s"
|
||||
|
||||
# Disable name/service resolution by default
|
||||
numeric = false
|
||||
|
||||
# Default fields to display (comma-separated list)
|
||||
fields = ["pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"]
|
||||
|
||||
# Default theme for TUI (dark, light, mono, auto)
|
||||
theme = "auto"
|
||||
|
||||
# Default units for byte display (auto, si, iec)
|
||||
units = "auto"
|
||||
|
||||
# Default color mode (auto, always, never)
|
||||
color = "auto"
|
||||
|
||||
# Enable name resolution by default
|
||||
resolve = true
|
||||
|
||||
# Filter options
|
||||
ipv4 = false
|
||||
ipv6 = false
|
||||
|
||||
# Output options
|
||||
no_headers = false
|
||||
output_format = "table"
|
||||
sort_by = ""
|
||||
`
|
||||
|
||||
// Ensure directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
// Write config file
|
||||
if err := os.WriteFile(path, []byte(exampleConfig), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
191
internal/resolver/resolver.go
Normal file
191
internal/resolver/resolver.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Resolver handles DNS and service name resolution with caching and timeouts
|
||||
type Resolver struct {
|
||||
timeout time.Duration
|
||||
cache map[string]string
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates a new resolver with the specified timeout
|
||||
func New(timeout time.Duration) *Resolver {
|
||||
return &Resolver{
|
||||
timeout: timeout,
|
||||
cache: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveAddr resolves an IP address to a hostname, with caching
|
||||
func (r *Resolver) ResolveAddr(addr string) string {
|
||||
// Check cache first
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[addr]; exists {
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
// Parse IP to validate it
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
// Not a valid IP, return as-is
|
||||
return addr
|
||||
}
|
||||
|
||||
// Perform resolution with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
names, err := net.DefaultResolver.LookupAddr(ctx, addr)
|
||||
|
||||
resolved := addr // fallback to original address
|
||||
if err == nil && len(names) > 0 {
|
||||
resolved = names[0]
|
||||
// Remove trailing dot if present
|
||||
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
|
||||
resolved = resolved[:len(resolved)-1]
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
r.mutex.Lock()
|
||||
r.cache[addr] = resolved
|
||||
r.mutex.Unlock()
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
// ResolvePort resolves a port number to a service name
|
||||
func (r *Resolver) ResolvePort(port int, proto string) string {
|
||||
if port == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
cacheKey := strconv.Itoa(port) + "/" + proto
|
||||
|
||||
// Check cache first
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[cacheKey]; exists {
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
// Perform resolution with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
service, err := net.DefaultResolver.LookupPort(ctx, proto, strconv.Itoa(port))
|
||||
|
||||
resolved := strconv.Itoa(port) // fallback to port number
|
||||
if err == nil && service != 0 {
|
||||
// Try to get service name
|
||||
if serviceName := getServiceName(port, proto); serviceName != "" {
|
||||
resolved = serviceName
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
r.mutex.Lock()
|
||||
r.cache[cacheKey] = resolved
|
||||
r.mutex.Unlock()
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
// ResolveAddrPort resolves both address and port
|
||||
func (r *Resolver) ResolveAddrPort(addr string, port int, proto string) (string, string) {
|
||||
resolvedAddr := r.ResolveAddr(addr)
|
||||
resolvedPort := r.ResolvePort(port, proto)
|
||||
return resolvedAddr, resolvedPort
|
||||
}
|
||||
|
||||
// ClearCache clears the resolution cache
|
||||
func (r *Resolver) ClearCache() {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.cache = make(map[string]string)
|
||||
}
|
||||
|
||||
// GetCacheSize returns the number of cached entries
|
||||
func (r *Resolver) GetCacheSize() int {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return len(r.cache)
|
||||
}
|
||||
|
||||
// getServiceName returns well-known service names for common ports
|
||||
func getServiceName(port int, proto string) string {
|
||||
// Common services - this could be expanded or loaded from /etc/services
|
||||
services := map[string]string{
|
||||
"80/tcp": "http",
|
||||
"443/tcp": "https",
|
||||
"22/tcp": "ssh",
|
||||
"21/tcp": "ftp",
|
||||
"25/tcp": "smtp",
|
||||
"53/tcp": "domain",
|
||||
"53/udp": "domain",
|
||||
"110/tcp": "pop3",
|
||||
"143/tcp": "imap",
|
||||
"993/tcp": "imaps",
|
||||
"995/tcp": "pop3s",
|
||||
"3306/tcp": "mysql",
|
||||
"5432/tcp": "postgresql",
|
||||
"6379/tcp": "redis",
|
||||
"3389/tcp": "rdp",
|
||||
"5900/tcp": "vnc",
|
||||
"23/tcp": "telnet",
|
||||
"69/udp": "tftp",
|
||||
"123/udp": "ntp",
|
||||
"161/udp": "snmp",
|
||||
"514/udp": "syslog",
|
||||
"67/udp": "bootps",
|
||||
"68/udp": "bootpc",
|
||||
}
|
||||
|
||||
key := strconv.Itoa(port) + "/" + proto
|
||||
if service, exists := services[key]; exists {
|
||||
return service
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Global resolver instance
|
||||
var globalResolver *Resolver
|
||||
|
||||
// SetGlobalResolver sets the global resolver instance
|
||||
func SetGlobalResolver(timeout time.Duration) {
|
||||
globalResolver = New(timeout)
|
||||
}
|
||||
|
||||
// GetGlobalResolver returns the global resolver instance
|
||||
func GetGlobalResolver() *Resolver {
|
||||
if globalResolver == nil {
|
||||
globalResolver = New(200 * time.Millisecond) // Default timeout
|
||||
}
|
||||
return globalResolver
|
||||
}
|
||||
|
||||
// ResolveAddr is a convenience function using the global resolver
|
||||
func ResolveAddr(addr string) string {
|
||||
return GetGlobalResolver().ResolveAddr(addr)
|
||||
}
|
||||
|
||||
// ResolvePort is a convenience function using the global resolver
|
||||
func ResolvePort(port int, proto string) string {
|
||||
return GetGlobalResolver().ResolvePort(port, proto)
|
||||
}
|
||||
|
||||
// ResolveAddrPort is a convenience function using the global resolver
|
||||
func ResolveAddrPort(addr string, port int, proto string) (string, string) {
|
||||
return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
|
||||
}
|
||||
215
internal/testutil/testutil.go
Normal file
215
internal/testutil/testutil.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"snitch/internal/collector"
|
||||
)
|
||||
|
||||
// TestCollector wraps MockCollector for use in tests
|
||||
type TestCollector struct {
|
||||
*collector.MockCollector
|
||||
}
|
||||
|
||||
// NewTestCollector creates a new test collector with default data
|
||||
func NewTestCollector() *TestCollector {
|
||||
return &TestCollector{
|
||||
MockCollector: collector.NewMockCollector(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestCollectorWithFixture creates a test collector with a specific fixture
|
||||
func NewTestCollectorWithFixture(fixtureName string) *TestCollector {
|
||||
fixtures := collector.GetTestFixtures()
|
||||
for _, fixture := range fixtures {
|
||||
if fixture.Name == fixtureName {
|
||||
mock := collector.NewMockCollector()
|
||||
mock.SetConnections(fixture.Connections)
|
||||
return &TestCollector{MockCollector: mock}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to default if fixture not found
|
||||
return NewTestCollector()
|
||||
}
|
||||
|
||||
// SetupTestEnvironment sets up a clean test environment
|
||||
func SetupTestEnvironment(t *testing.T) (string, func()) {
|
||||
// Create temporary directory for test files
|
||||
tempDir, err := os.MkdirTemp("", "snitch-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
// Set test environment variables
|
||||
oldConfig := os.Getenv("SNITCH_CONFIG")
|
||||
oldNoColor := os.Getenv("SNITCH_NO_COLOR")
|
||||
|
||||
os.Setenv("SNITCH_NO_COLOR", "1") // Disable colors in tests
|
||||
|
||||
// Cleanup function
|
||||
cleanup := func() {
|
||||
os.RemoveAll(tempDir)
|
||||
os.Setenv("SNITCH_CONFIG", oldConfig)
|
||||
os.Setenv("SNITCH_NO_COLOR", oldNoColor)
|
||||
}
|
||||
|
||||
return tempDir, cleanup
|
||||
}
|
||||
|
||||
// CreateFixtureFile creates a JSON fixture file with the given connections
|
||||
func CreateFixtureFile(t *testing.T, dir string, name string, connections []collector.Connection) string {
|
||||
mock := collector.NewMockCollector()
|
||||
mock.SetConnections(connections)
|
||||
|
||||
filePath := filepath.Join(dir, name+".json")
|
||||
if err := mock.SaveToFile(filePath); err != nil {
|
||||
t.Fatalf("Failed to create fixture file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
return filePath
|
||||
}
|
||||
|
||||
// LoadFixtureFile loads connections from a JSON fixture file
|
||||
func LoadFixtureFile(t *testing.T, filePath string) []collector.Connection {
|
||||
mock, err := collector.NewMockCollectorFromFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load fixture file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
connections, err := mock.GetConnections()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get connections from fixture: %v", err)
|
||||
}
|
||||
|
||||
return connections
|
||||
}
|
||||
|
||||
// AssertConnectionsEqual compares two slices of connections for equality
|
||||
func AssertConnectionsEqual(t *testing.T, expected, actual []collector.Connection) {
|
||||
if len(expected) != len(actual) {
|
||||
t.Errorf("Connection count mismatch: expected %d, got %d", len(expected), len(actual))
|
||||
return
|
||||
}
|
||||
|
||||
for i, exp := range expected {
|
||||
act := actual[i]
|
||||
|
||||
// Compare key fields (timestamps may vary slightly)
|
||||
if exp.PID != act.PID {
|
||||
t.Errorf("Connection %d PID mismatch: expected %d, got %d", i, exp.PID, act.PID)
|
||||
}
|
||||
if exp.Process != act.Process {
|
||||
t.Errorf("Connection %d Process mismatch: expected %s, got %s", i, exp.Process, act.Process)
|
||||
}
|
||||
if exp.Proto != act.Proto {
|
||||
t.Errorf("Connection %d Proto mismatch: expected %s, got %s", i, exp.Proto, act.Proto)
|
||||
}
|
||||
if exp.State != act.State {
|
||||
t.Errorf("Connection %d State mismatch: expected %s, got %s", i, exp.State, act.State)
|
||||
}
|
||||
if exp.Laddr != act.Laddr {
|
||||
t.Errorf("Connection %d Laddr mismatch: expected %s, got %s", i, exp.Laddr, act.Laddr)
|
||||
}
|
||||
if exp.Lport != act.Lport {
|
||||
t.Errorf("Connection %d Lport mismatch: expected %d, got %d", i, exp.Lport, act.Lport)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTestConfig returns a test configuration with safe defaults
|
||||
func GetTestConfig() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"defaults": map[string]interface{}{
|
||||
"interval": "1s",
|
||||
"numeric": true, // Disable resolution in tests
|
||||
"fields": []string{"pid", "process", "proto", "state", "laddr", "lport"},
|
||||
"theme": "mono", // Use monochrome theme in tests
|
||||
"units": "auto",
|
||||
"color": "never",
|
||||
"resolve": false,
|
||||
"ipv4": false,
|
||||
"ipv6": false,
|
||||
"no_headers": false,
|
||||
"output_format": "table",
|
||||
"sort_by": "",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CaptureOutput captures stdout/stderr during test execution
|
||||
type OutputCapture struct {
|
||||
stdout *os.File
|
||||
stderr *os.File
|
||||
oldStdout *os.File
|
||||
oldStderr *os.File
|
||||
stdoutFile string
|
||||
stderrFile string
|
||||
}
|
||||
|
||||
// NewOutputCapture creates a new output capture
|
||||
func NewOutputCapture(t *testing.T) *OutputCapture {
|
||||
tempDir, err := os.MkdirTemp("", "snitch-output-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir for output capture: %v", err)
|
||||
}
|
||||
|
||||
stdoutFile := filepath.Join(tempDir, "stdout")
|
||||
stderrFile := filepath.Join(tempDir, "stderr")
|
||||
|
||||
stdout, err := os.Create(stdoutFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create stdout file: %v", err)
|
||||
}
|
||||
|
||||
stderr, err := os.Create(stderrFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create stderr file: %v", err)
|
||||
}
|
||||
|
||||
return &OutputCapture{
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
oldStdout: os.Stdout,
|
||||
oldStderr: os.Stderr,
|
||||
stdoutFile: stdoutFile,
|
||||
stderrFile: stderrFile,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins capturing output
|
||||
func (oc *OutputCapture) Start() {
|
||||
os.Stdout = oc.stdout
|
||||
os.Stderr = oc.stderr
|
||||
}
|
||||
|
||||
// Stop stops capturing and returns the captured output
|
||||
func (oc *OutputCapture) Stop() (string, string, error) {
|
||||
// Restore original stdout/stderr
|
||||
os.Stdout = oc.oldStdout
|
||||
os.Stderr = oc.oldStderr
|
||||
|
||||
// Close files
|
||||
oc.stdout.Close()
|
||||
oc.stderr.Close()
|
||||
|
||||
// Read captured content
|
||||
stdoutContent, err := os.ReadFile(oc.stdoutFile)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
stderrContent, err := os.ReadFile(oc.stderrFile)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
os.Remove(oc.stdoutFile)
|
||||
os.Remove(oc.stderrFile)
|
||||
os.Remove(filepath.Dir(oc.stdoutFile))
|
||||
|
||||
return string(stdoutContent), string(stderrContent), nil
|
||||
}
|
||||
247
internal/theme/theme.go
Normal file
247
internal/theme/theme.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package theme
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Theme represents the visual styling for the TUI
|
||||
type Theme struct {
|
||||
Name string
|
||||
Styles Styles
|
||||
}
|
||||
|
||||
// Styles contains all the styling definitions
|
||||
type Styles struct {
|
||||
Header lipgloss.Style
|
||||
Border lipgloss.Style
|
||||
Selected lipgloss.Style
|
||||
Watched lipgloss.Style
|
||||
Normal lipgloss.Style
|
||||
Error lipgloss.Style
|
||||
Success lipgloss.Style
|
||||
Warning lipgloss.Style
|
||||
Proto ProtoStyles
|
||||
State StateStyles
|
||||
Footer lipgloss.Style
|
||||
Background lipgloss.Style
|
||||
}
|
||||
|
||||
// ProtoStyles contains protocol-specific colors
|
||||
type ProtoStyles struct {
|
||||
TCP lipgloss.Style
|
||||
UDP lipgloss.Style
|
||||
Unix lipgloss.Style
|
||||
TCP6 lipgloss.Style
|
||||
UDP6 lipgloss.Style
|
||||
}
|
||||
|
||||
// StateStyles contains connection state-specific colors
|
||||
type StateStyles struct {
|
||||
Listen lipgloss.Style
|
||||
Established lipgloss.Style
|
||||
TimeWait lipgloss.Style
|
||||
CloseWait lipgloss.Style
|
||||
SynSent lipgloss.Style
|
||||
SynRecv lipgloss.Style
|
||||
FinWait1 lipgloss.Style
|
||||
FinWait2 lipgloss.Style
|
||||
Closing lipgloss.Style
|
||||
LastAck lipgloss.Style
|
||||
Closed lipgloss.Style
|
||||
}
|
||||
|
||||
var (
|
||||
themes map[string]*Theme
|
||||
)
|
||||
|
||||
func init() {
|
||||
themes = map[string]*Theme{
|
||||
"default": createAdaptiveTheme(),
|
||||
"mono": createMonoTheme(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetTheme returns a theme by name, with auto-detection support
|
||||
func GetTheme(name string) *Theme {
|
||||
if name == "auto" {
|
||||
// lipgloss handles adaptive colors, so we just return the default
|
||||
return themes["default"]
|
||||
}
|
||||
|
||||
if theme, exists := themes[name]; exists {
|
||||
return theme
|
||||
}
|
||||
|
||||
// a specific theme was requested (e.g. "dark", "light"), but we now use adaptive
|
||||
// so we can just return the default theme and lipgloss will handle it
|
||||
if name == "dark" || name == "light" {
|
||||
return themes["default"]
|
||||
}
|
||||
|
||||
// fallback to default
|
||||
return themes["default"]
|
||||
}
|
||||
|
||||
// ListThemes returns available theme names
|
||||
func ListThemes() []string {
|
||||
var names []string
|
||||
for name := range themes {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// createAdaptiveTheme creates a clean, minimal theme
|
||||
func createAdaptiveTheme() *Theme {
|
||||
return &Theme{
|
||||
Name: "default",
|
||||
Styles: Styles{
|
||||
Header: lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#1F2937", Dark: "#F9FAFB"}),
|
||||
|
||||
Watched: lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#F59E0B"}),
|
||||
|
||||
Border: lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#D1D5DB", Dark: "#374151"}),
|
||||
|
||||
Selected: lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#1F2937", Dark: "#F9FAFB"}),
|
||||
|
||||
Normal: lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#6B7280", Dark: "#9CA3AF"}),
|
||||
|
||||
Error: lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
|
||||
|
||||
Success: lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
|
||||
|
||||
Warning: lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
|
||||
|
||||
Footer: lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#9CA3AF", Dark: "#6B7280"}),
|
||||
|
||||
Background: lipgloss.NewStyle(),
|
||||
|
||||
Proto: ProtoStyles{
|
||||
TCP: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
|
||||
UDP: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
|
||||
Unix: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#6B7280", Dark: "#9CA3AF"}),
|
||||
TCP6: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
|
||||
UDP6: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
|
||||
},
|
||||
|
||||
State: StateStyles{
|
||||
Listen: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
|
||||
Established: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#2563EB", Dark: "#60A5FA"}),
|
||||
TimeWait: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
|
||||
CloseWait: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
|
||||
SynSent: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
|
||||
SynRecv: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
|
||||
FinWait1: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
|
||||
FinWait2: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
|
||||
Closing: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
|
||||
LastAck: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
|
||||
Closed: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#9CA3AF", Dark: "#6B7280"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// createMonoTheme creates a monochrome theme (no colors)
|
||||
func createMonoTheme() *Theme {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
boldStyle := lipgloss.NewStyle().Bold(true)
|
||||
|
||||
return &Theme{
|
||||
Name: "mono",
|
||||
Styles: Styles{
|
||||
Header: boldStyle,
|
||||
Border: baseStyle,
|
||||
Selected: boldStyle,
|
||||
Normal: baseStyle,
|
||||
Error: boldStyle,
|
||||
Success: boldStyle,
|
||||
Warning: boldStyle,
|
||||
Footer: baseStyle,
|
||||
Background: baseStyle,
|
||||
|
||||
Proto: ProtoStyles{
|
||||
TCP: baseStyle,
|
||||
UDP: baseStyle,
|
||||
Unix: baseStyle,
|
||||
TCP6: baseStyle,
|
||||
UDP6: baseStyle,
|
||||
},
|
||||
|
||||
State: StateStyles{
|
||||
Listen: baseStyle,
|
||||
Established: baseStyle,
|
||||
TimeWait: baseStyle,
|
||||
CloseWait: baseStyle,
|
||||
SynSent: baseStyle,
|
||||
SynRecv: baseStyle,
|
||||
FinWait1: baseStyle,
|
||||
FinWait2: baseStyle,
|
||||
Closing: baseStyle,
|
||||
LastAck: baseStyle,
|
||||
Closed: baseStyle,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetProtoStyle returns the appropriate style for a protocol
|
||||
func (s *Styles) GetProtoStyle(proto string) lipgloss.Style {
|
||||
switch strings.ToLower(proto) {
|
||||
case "tcp":
|
||||
return s.Proto.TCP
|
||||
case "udp":
|
||||
return s.Proto.UDP
|
||||
case "unix":
|
||||
return s.Proto.Unix
|
||||
case "tcp6":
|
||||
return s.Proto.TCP6
|
||||
case "udp6":
|
||||
return s.Proto.UDP6
|
||||
default:
|
||||
return s.Normal
|
||||
}
|
||||
}
|
||||
|
||||
// GetStateStyle returns the appropriate style for a connection state
|
||||
func (s *Styles) GetStateStyle(state string) lipgloss.Style {
|
||||
switch strings.ToUpper(state) {
|
||||
case "LISTEN":
|
||||
return s.State.Listen
|
||||
case "ESTABLISHED":
|
||||
return s.State.Established
|
||||
case "TIME_WAIT":
|
||||
return s.State.TimeWait
|
||||
case "CLOSE_WAIT":
|
||||
return s.State.CloseWait
|
||||
case "SYN_SENT":
|
||||
return s.State.SynSent
|
||||
case "SYN_RECV":
|
||||
return s.State.SynRecv
|
||||
case "FIN_WAIT1":
|
||||
return s.State.FinWait1
|
||||
case "FIN_WAIT2":
|
||||
return s.State.FinWait2
|
||||
case "CLOSING":
|
||||
return s.State.Closing
|
||||
case "LAST_ACK":
|
||||
return s.State.LastAck
|
||||
case "CLOSED":
|
||||
return s.State.Closed
|
||||
default:
|
||||
return s.Normal
|
||||
}
|
||||
}
|
||||
53
internal/tui/helpers.go
Normal file
53
internal/tui/helpers.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"snitch/internal/collector"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
if max <= 2 {
|
||||
return s[:max]
|
||||
}
|
||||
return s[:max-1] + "…"
|
||||
}
|
||||
|
||||
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
func stripAnsi(s string) string {
|
||||
return ansiRegex.ReplaceAllString(s, "")
|
||||
}
|
||||
|
||||
func containsIgnoreCase(s, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
|
||||
}
|
||||
|
||||
func sortFieldLabel(f collector.SortField) string {
|
||||
switch f {
|
||||
case collector.SortByLport:
|
||||
return "port"
|
||||
case collector.SortByProcess:
|
||||
return "proc"
|
||||
case collector.SortByPID:
|
||||
return "pid"
|
||||
case collector.SortByState:
|
||||
return "state"
|
||||
case collector.SortByProto:
|
||||
return "proto"
|
||||
default:
|
||||
return "port"
|
||||
}
|
||||
}
|
||||
|
||||
func formatRemote(addr string, port int) string {
|
||||
if addr == "" || addr == "*" || port == 0 {
|
||||
return "-"
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", addr, port)
|
||||
}
|
||||
|
||||
185
internal/tui/keys.go
Normal file
185
internal/tui/keys.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"snitch/internal/collector"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
// search mode captures all input
|
||||
if m.searchActive {
|
||||
return m.handleSearchKey(msg)
|
||||
}
|
||||
|
||||
// detail view only allows closing
|
||||
if m.showDetail {
|
||||
return m.handleDetailKey(msg)
|
||||
}
|
||||
|
||||
// help view only allows closing
|
||||
if m.showHelp {
|
||||
return m.handleHelpKey(msg)
|
||||
}
|
||||
|
||||
return m.handleNormalKey(msg)
|
||||
}
|
||||
|
||||
func (m model) handleSearchKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.searchActive = false
|
||||
m.searchQuery = ""
|
||||
case "enter":
|
||||
m.searchActive = false
|
||||
m.cursor = 0
|
||||
case "backspace":
|
||||
if len(m.searchQuery) > 0 {
|
||||
m.searchQuery = m.searchQuery[:len(m.searchQuery)-1]
|
||||
}
|
||||
default:
|
||||
if len(msg.String()) == 1 {
|
||||
m.searchQuery += msg.String()
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "esc", "enter", "q":
|
||||
m.showDetail = false
|
||||
m.selected = nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) handleHelpKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "esc", "enter", "q", "?":
|
||||
m.showHelp = false
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "q", "ctrl+c":
|
||||
return m, tea.Sequence(tea.ShowCursor, tea.Quit)
|
||||
|
||||
// navigation
|
||||
case "j", "down":
|
||||
m.moveCursor(1)
|
||||
case "k", "up":
|
||||
m.moveCursor(-1)
|
||||
case "g":
|
||||
m.cursor = 0
|
||||
case "G":
|
||||
visible := m.visibleConnections()
|
||||
if len(visible) > 0 {
|
||||
m.cursor = len(visible) - 1
|
||||
}
|
||||
case "ctrl+d":
|
||||
m.moveCursor(m.pageSize() / 2)
|
||||
case "ctrl+u":
|
||||
m.moveCursor(-m.pageSize() / 2)
|
||||
case "ctrl+f", "pgdown":
|
||||
m.moveCursor(m.pageSize())
|
||||
case "ctrl+b", "pgup":
|
||||
m.moveCursor(-m.pageSize())
|
||||
|
||||
// filter toggles
|
||||
case "t":
|
||||
m.showTCP = !m.showTCP
|
||||
m.clampCursor()
|
||||
case "u":
|
||||
m.showUDP = !m.showUDP
|
||||
m.clampCursor()
|
||||
case "l":
|
||||
m.showListening = !m.showListening
|
||||
m.clampCursor()
|
||||
case "e":
|
||||
m.showEstablished = !m.showEstablished
|
||||
m.clampCursor()
|
||||
case "o":
|
||||
m.showOther = !m.showOther
|
||||
m.clampCursor()
|
||||
case "a":
|
||||
m.showTCP = true
|
||||
m.showUDP = true
|
||||
m.showListening = true
|
||||
m.showEstablished = true
|
||||
m.showOther = true
|
||||
|
||||
// sorting
|
||||
case "s":
|
||||
m.cycleSort()
|
||||
case "S":
|
||||
m.sortReverse = !m.sortReverse
|
||||
m.applySorting()
|
||||
|
||||
// search
|
||||
case "/":
|
||||
m.searchActive = true
|
||||
m.searchQuery = ""
|
||||
|
||||
// actions
|
||||
case "enter", " ":
|
||||
visible := m.visibleConnections()
|
||||
if m.cursor < len(visible) {
|
||||
conn := visible[m.cursor]
|
||||
m.selected = &conn
|
||||
m.showDetail = true
|
||||
}
|
||||
case "r":
|
||||
return m, m.fetchData()
|
||||
case "?":
|
||||
m.showHelp = true
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) moveCursor(delta int) {
|
||||
visible := m.visibleConnections()
|
||||
m.cursor += delta
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
if m.cursor >= len(visible) {
|
||||
m.cursor = len(visible) - 1
|
||||
}
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m model) pageSize() int {
|
||||
size := m.height - 6
|
||||
if size < 1 {
|
||||
return 10
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (m *model) cycleSort() {
|
||||
fields := []collector.SortField{
|
||||
collector.SortByLport,
|
||||
collector.SortByProcess,
|
||||
collector.SortByPID,
|
||||
collector.SortByState,
|
||||
collector.SortByProto,
|
||||
}
|
||||
|
||||
for i, f := range fields {
|
||||
if f == m.sortField {
|
||||
m.sortField = fields[(i+1)%len(fields)]
|
||||
m.applySorting()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.sortField = collector.SortByLport
|
||||
m.applySorting()
|
||||
}
|
||||
|
||||
35
internal/tui/messages.go
Normal file
35
internal/tui/messages.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"snitch/internal/collector"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type tickMsg time.Time
|
||||
|
||||
type dataMsg struct {
|
||||
connections []collector.Connection
|
||||
}
|
||||
|
||||
type errMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m model) tick() tea.Cmd {
|
||||
return tea.Tick(m.interval, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
func (m model) fetchData() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
conns, err := collector.GetConnections()
|
||||
if err != nil {
|
||||
return errMsg{err}
|
||||
}
|
||||
return dataMsg{connections: conns}
|
||||
}
|
||||
}
|
||||
|
||||
220
internal/tui/model.go
Normal file
220
internal/tui/model.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"snitch/internal/collector"
|
||||
"snitch/internal/theme"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type model struct {
|
||||
connections []collector.Connection
|
||||
cursor int
|
||||
width int
|
||||
height int
|
||||
|
||||
// filtering
|
||||
showTCP bool
|
||||
showUDP bool
|
||||
showListening bool
|
||||
showEstablished bool
|
||||
showOther bool
|
||||
searchQuery string
|
||||
searchActive bool
|
||||
|
||||
// sorting
|
||||
sortField collector.SortField
|
||||
sortReverse bool
|
||||
|
||||
// ui state
|
||||
theme *theme.Theme
|
||||
showHelp bool
|
||||
showDetail bool
|
||||
selected *collector.Connection
|
||||
interval time.Duration
|
||||
lastRefresh time.Time
|
||||
err error
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Theme string
|
||||
Interval time.Duration
|
||||
TCP bool
|
||||
UDP bool
|
||||
Listening bool
|
||||
Established bool
|
||||
Other bool
|
||||
FilterSet bool // true if user specified any filter flags
|
||||
}
|
||||
|
||||
func New(opts Options) model {
|
||||
interval := opts.Interval
|
||||
if interval == 0 {
|
||||
interval = time.Second
|
||||
}
|
||||
|
||||
// default: show everything
|
||||
showTCP := true
|
||||
showUDP := true
|
||||
showListening := true
|
||||
showEstablished := true
|
||||
showOther := true
|
||||
|
||||
// if user specified filters, use those instead
|
||||
if opts.FilterSet {
|
||||
showTCP = opts.TCP
|
||||
showUDP = opts.UDP
|
||||
showListening = opts.Listening
|
||||
showEstablished = opts.Established
|
||||
showOther = opts.Other
|
||||
|
||||
// if only proto filters set, show all states
|
||||
if !opts.Listening && !opts.Established && !opts.Other {
|
||||
showListening = true
|
||||
showEstablished = true
|
||||
showOther = true
|
||||
}
|
||||
// if only state filters set, show all protos
|
||||
if !opts.TCP && !opts.UDP {
|
||||
showTCP = true
|
||||
showUDP = true
|
||||
}
|
||||
}
|
||||
|
||||
return model{
|
||||
connections: []collector.Connection{},
|
||||
showTCP: showTCP,
|
||||
showUDP: showUDP,
|
||||
showListening: showListening,
|
||||
showEstablished: showEstablished,
|
||||
showOther: showOther,
|
||||
sortField: collector.SortByLport,
|
||||
theme: theme.GetTheme(opts.Theme),
|
||||
interval: interval,
|
||||
lastRefresh: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return tea.Batch(
|
||||
tea.HideCursor,
|
||||
m.fetchData(),
|
||||
m.tick(),
|
||||
)
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
return m.handleKey(msg)
|
||||
|
||||
case tickMsg:
|
||||
return m, tea.Batch(m.fetchData(), m.tick())
|
||||
|
||||
case dataMsg:
|
||||
m.connections = msg.connections
|
||||
m.lastRefresh = time.Now()
|
||||
m.applySorting()
|
||||
m.clampCursor()
|
||||
return m, nil
|
||||
|
||||
case errMsg:
|
||||
m.err = msg.err
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.err != nil {
|
||||
return m.renderError()
|
||||
}
|
||||
if m.showHelp {
|
||||
return m.renderHelp()
|
||||
}
|
||||
if m.showDetail && m.selected != nil {
|
||||
return m.renderDetail()
|
||||
}
|
||||
return m.renderMain()
|
||||
}
|
||||
|
||||
func (m *model) applySorting() {
|
||||
direction := collector.SortAsc
|
||||
if m.sortReverse {
|
||||
direction = collector.SortDesc
|
||||
}
|
||||
collector.SortConnections(m.connections, collector.SortOptions{
|
||||
Field: m.sortField,
|
||||
Direction: direction,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *model) clampCursor() {
|
||||
visible := m.visibleConnections()
|
||||
if m.cursor >= len(visible) {
|
||||
m.cursor = len(visible) - 1
|
||||
}
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m model) visibleConnections() []collector.Connection {
|
||||
var result []collector.Connection
|
||||
|
||||
for _, c := range m.connections {
|
||||
if !m.matchesFilters(c) {
|
||||
continue
|
||||
}
|
||||
if m.searchQuery != "" && !m.matchesSearch(c) {
|
||||
continue
|
||||
}
|
||||
result = append(result, c)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (m model) matchesFilters(c collector.Connection) bool {
|
||||
isTCP := c.Proto == "tcp" || c.Proto == "tcp6"
|
||||
isUDP := c.Proto == "udp" || c.Proto == "udp6"
|
||||
|
||||
if isTCP && !m.showTCP {
|
||||
return false
|
||||
}
|
||||
if isUDP && !m.showUDP {
|
||||
return false
|
||||
}
|
||||
|
||||
isListening := c.State == "LISTEN"
|
||||
isEstablished := c.State == "ESTABLISHED"
|
||||
isOther := !isListening && !isEstablished
|
||||
|
||||
if isListening && !m.showListening {
|
||||
return false
|
||||
}
|
||||
if isEstablished && !m.showEstablished {
|
||||
return false
|
||||
}
|
||||
if isOther && !m.showOther {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m model) matchesSearch(c collector.Connection) bool {
|
||||
return containsIgnoreCase(c.Process, m.searchQuery) ||
|
||||
containsIgnoreCase(c.Laddr, m.searchQuery) ||
|
||||
containsIgnoreCase(c.Raddr, m.searchQuery) ||
|
||||
containsIgnoreCase(c.User, m.searchQuery) ||
|
||||
containsIgnoreCase(c.Proto, m.searchQuery) ||
|
||||
containsIgnoreCase(c.State, m.searchQuery)
|
||||
}
|
||||
352
internal/tui/view.go
Normal file
352
internal/tui/view.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"snitch/internal/collector"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (m model) renderMain() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(m.renderTitle())
|
||||
b.WriteString("\n")
|
||||
b.WriteString(m.renderFilters())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(m.renderTableHeader())
|
||||
b.WriteString(m.renderSeparator())
|
||||
b.WriteString(m.renderConnections())
|
||||
b.WriteString("\n")
|
||||
b.WriteString(m.renderStatusLine())
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m model) renderTitle() string {
|
||||
visible := m.visibleConnections()
|
||||
total := len(m.connections)
|
||||
|
||||
left := m.theme.Styles.Header.Render("snitch")
|
||||
|
||||
ago := time.Since(m.lastRefresh).Round(time.Millisecond * 100)
|
||||
right := m.theme.Styles.Normal.Render(fmt.Sprintf("%d/%d connections ↻ %s", len(visible), total, formatDuration(ago)))
|
||||
|
||||
w := m.safeWidth()
|
||||
gap := w - len(stripAnsi(left)) - len(stripAnsi(right)) - 2
|
||||
if gap < 0 {
|
||||
gap = 0
|
||||
}
|
||||
|
||||
return " " + left + strings.Repeat(" ", gap) + right
|
||||
}
|
||||
|
||||
func (m model) renderFilters() string {
|
||||
var parts []string
|
||||
|
||||
if m.showTCP {
|
||||
parts = append(parts, m.theme.Styles.Success.Render("tcp"))
|
||||
} else {
|
||||
parts = append(parts, m.theme.Styles.Normal.Render("tcp"))
|
||||
}
|
||||
|
||||
if m.showUDP {
|
||||
parts = append(parts, m.theme.Styles.Success.Render("udp"))
|
||||
} else {
|
||||
parts = append(parts, m.theme.Styles.Normal.Render("udp"))
|
||||
}
|
||||
|
||||
parts = append(parts, m.theme.Styles.Border.Render("│"))
|
||||
|
||||
if m.showListening {
|
||||
parts = append(parts, m.theme.Styles.Success.Render("listen"))
|
||||
} else {
|
||||
parts = append(parts, m.theme.Styles.Normal.Render("listen"))
|
||||
}
|
||||
|
||||
if m.showEstablished {
|
||||
parts = append(parts, m.theme.Styles.Success.Render("estab"))
|
||||
} else {
|
||||
parts = append(parts, m.theme.Styles.Normal.Render("estab"))
|
||||
}
|
||||
|
||||
if m.showOther {
|
||||
parts = append(parts, m.theme.Styles.Success.Render("other"))
|
||||
} else {
|
||||
parts = append(parts, m.theme.Styles.Normal.Render("other"))
|
||||
}
|
||||
|
||||
left := " " + strings.Join(parts, " ")
|
||||
|
||||
sortLabel := sortFieldLabel(m.sortField)
|
||||
sortDir := "↑"
|
||||
if m.sortReverse {
|
||||
sortDir = "↓"
|
||||
}
|
||||
|
||||
var right string
|
||||
if m.searchActive {
|
||||
right = m.theme.Styles.Warning.Render(fmt.Sprintf("/%s▌", m.searchQuery))
|
||||
} else if m.searchQuery != "" {
|
||||
right = m.theme.Styles.Normal.Render(fmt.Sprintf("filter: %s", m.searchQuery))
|
||||
} else {
|
||||
right = m.theme.Styles.Normal.Render(fmt.Sprintf("sort: %s %s", sortLabel, sortDir))
|
||||
}
|
||||
|
||||
w := m.safeWidth()
|
||||
gap := w - len(stripAnsi(left)) - len(stripAnsi(right)) - 2
|
||||
if gap < 0 {
|
||||
gap = 0
|
||||
}
|
||||
|
||||
return left + strings.Repeat(" ", gap) + right + " "
|
||||
}
|
||||
|
||||
func (m model) renderTableHeader() string {
|
||||
cols := m.columnWidths()
|
||||
|
||||
header := fmt.Sprintf(" %-*s %-*s %-*s %-*s %-*s %s",
|
||||
cols.process, "PROCESS",
|
||||
cols.port, "PORT",
|
||||
cols.proto, "PROTO",
|
||||
cols.state, "STATE",
|
||||
cols.local, "LOCAL",
|
||||
"REMOTE")
|
||||
|
||||
return m.theme.Styles.Header.Render(header) + "\n"
|
||||
}
|
||||
|
||||
func (m model) renderSeparator() string {
|
||||
w := m.width - 4
|
||||
if w < 1 {
|
||||
w = 76
|
||||
}
|
||||
line := " " + strings.Repeat("─", w)
|
||||
return m.theme.Styles.Border.Render(line) + "\n"
|
||||
}
|
||||
|
||||
func (m model) renderConnections() string {
|
||||
var b strings.Builder
|
||||
visible := m.visibleConnections()
|
||||
pageSize := m.pageSize()
|
||||
|
||||
if len(visible) == 0 {
|
||||
empty := "\n " + m.theme.Styles.Normal.Render("no connections match filters") + "\n"
|
||||
return empty
|
||||
}
|
||||
|
||||
start := m.scrollOffset(pageSize, len(visible))
|
||||
|
||||
for i := 0; i < pageSize; i++ {
|
||||
idx := start + i
|
||||
if idx >= len(visible) {
|
||||
b.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
isSelected := idx == m.cursor
|
||||
b.WriteString(m.renderRow(visible[idx], isSelected))
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m model) renderRow(c collector.Connection, selected bool) string {
|
||||
cols := m.columnWidths()
|
||||
|
||||
indicator := " "
|
||||
if selected {
|
||||
indicator = m.theme.Styles.Success.Render("▸ ")
|
||||
}
|
||||
|
||||
process := truncate(c.Process, cols.process)
|
||||
if process == "" {
|
||||
process = "–"
|
||||
}
|
||||
|
||||
port := fmt.Sprintf("%d", c.Lport)
|
||||
proto := c.Proto
|
||||
state := c.State
|
||||
if state == "" {
|
||||
state = "–"
|
||||
}
|
||||
|
||||
local := c.Laddr
|
||||
if local == "*" || local == "" {
|
||||
local = "*"
|
||||
}
|
||||
|
||||
remote := formatRemote(c.Raddr, c.Rport)
|
||||
|
||||
// apply styling
|
||||
protoStyled := m.theme.Styles.GetProtoStyle(proto).Render(fmt.Sprintf("%-*s", cols.proto, proto))
|
||||
stateStyled := m.theme.Styles.GetStateStyle(state).Render(fmt.Sprintf("%-*s", cols.state, truncate(state, cols.state)))
|
||||
|
||||
row := fmt.Sprintf("%s%-*s %-*s %s %s %-*s %s",
|
||||
indicator,
|
||||
cols.process, process,
|
||||
cols.port, port,
|
||||
protoStyled,
|
||||
stateStyled,
|
||||
cols.local, truncate(local, cols.local),
|
||||
truncate(remote, cols.remote))
|
||||
|
||||
if selected {
|
||||
return m.theme.Styles.Selected.Render(row) + "\n"
|
||||
}
|
||||
|
||||
return m.theme.Styles.Normal.Render(row) + "\n"
|
||||
}
|
||||
|
||||
func (m model) renderStatusLine() string {
|
||||
left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state s sort / search ? help q quit")
|
||||
|
||||
return left
|
||||
}
|
||||
|
||||
func (m model) renderError() string {
|
||||
return fmt.Sprintf("\n %s\n\n press q to quit\n",
|
||||
m.theme.Styles.Error.Render(fmt.Sprintf("error: %v", m.err)))
|
||||
}
|
||||
|
||||
func (m model) renderHelp() string {
|
||||
help := `
|
||||
navigation
|
||||
──────────
|
||||
j/k ↑/↓ move cursor
|
||||
g/G jump to top/bottom
|
||||
ctrl+d/u half page down/up
|
||||
enter show connection details
|
||||
|
||||
filters
|
||||
───────
|
||||
t toggle tcp
|
||||
u toggle udp
|
||||
l toggle listening
|
||||
e toggle established
|
||||
o toggle other states
|
||||
a reset all filters
|
||||
|
||||
sorting
|
||||
───────
|
||||
s cycle sort field
|
||||
S reverse sort order
|
||||
|
||||
other
|
||||
─────
|
||||
/ search
|
||||
r refresh now
|
||||
q quit
|
||||
|
||||
press ? or esc to close
|
||||
`
|
||||
return m.theme.Styles.Normal.Render(help)
|
||||
}
|
||||
|
||||
func (m model) renderDetail() string {
|
||||
if m.selected == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
c := m.selected
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(" " + m.theme.Styles.Header.Render("connection details") + "\n")
|
||||
b.WriteString(" " + m.theme.Styles.Border.Render(strings.Repeat("─", 40)) + "\n\n")
|
||||
|
||||
fields := []struct {
|
||||
label string
|
||||
value string
|
||||
}{
|
||||
{"process", c.Process},
|
||||
{"pid", fmt.Sprintf("%d", c.PID)},
|
||||
{"user", c.User},
|
||||
{"protocol", c.Proto},
|
||||
{"state", c.State},
|
||||
{"local", fmt.Sprintf("%s:%d", c.Laddr, c.Lport)},
|
||||
{"remote", fmt.Sprintf("%s:%d", c.Raddr, c.Rport)},
|
||||
{"interface", c.Interface},
|
||||
{"inode", fmt.Sprintf("%d", c.Inode)},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
val := f.value
|
||||
if val == "" || val == "0" || val == ":0" {
|
||||
val = "–"
|
||||
}
|
||||
line := fmt.Sprintf(" %-12s %s\n", m.theme.Styles.Header.Render(f.label), val)
|
||||
b.WriteString(line)
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(" " + m.theme.Styles.Normal.Render("press esc to close") + "\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m model) scrollOffset(pageSize, total int) int {
|
||||
if total <= pageSize {
|
||||
return 0
|
||||
}
|
||||
|
||||
// keep cursor roughly centered
|
||||
offset := m.cursor - pageSize/2
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if offset > total-pageSize {
|
||||
offset = total - pageSize
|
||||
}
|
||||
return offset
|
||||
}
|
||||
|
||||
type columns struct {
|
||||
process int
|
||||
port int
|
||||
proto int
|
||||
state int
|
||||
local int
|
||||
remote int
|
||||
}
|
||||
|
||||
func (m model) columnWidths() columns {
|
||||
available := m.safeWidth() - 16
|
||||
|
||||
c := columns{
|
||||
process: 16,
|
||||
port: 6,
|
||||
proto: 5,
|
||||
state: 11,
|
||||
local: 15,
|
||||
remote: 20,
|
||||
}
|
||||
|
||||
used := c.process + c.port + c.proto + c.state + c.local + c.remote
|
||||
extra := available - used
|
||||
|
||||
if extra > 0 {
|
||||
c.process += extra / 3
|
||||
c.remote += extra - extra/3
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (m model) safeWidth() int {
|
||||
if m.width < 80 {
|
||||
return 80
|
||||
}
|
||||
return m.width
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
if d < time.Second {
|
||||
return fmt.Sprintf("%dms", d.Milliseconds())
|
||||
}
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%.1fs", d.Seconds())
|
||||
}
|
||||
return fmt.Sprintf("%.0fm", d.Minutes())
|
||||
}
|
||||
Reference in New Issue
Block a user