From bdc4de02295d12fedea0d4c82ec5e0217a4b2eae Mon Sep 17 00:00:00 2001 From: Karol Broda <122811026+karol-broda@users.noreply.github.com> Date: Mon, 29 Dec 2025 19:47:32 +0100 Subject: [PATCH] feat: add port search, remote sort, export, and process info (#27) --- cmd/ls.go | 72 +++++ internal/collector/collector_darwin.go | 28 +- internal/collector/collector_linux.go | 51 ++-- internal/collector/collector_test.go | 56 ++++ internal/collector/sort_test.go | 72 +++++ internal/collector/types.go | 2 + internal/tui/helpers.go | 4 + internal/tui/keys.go | 94 ++++++- internal/tui/model.go | 82 +++++- internal/tui/model_test.go | 348 ++++++++++++++++++++++++- internal/tui/symbols.go | 4 +- internal/tui/view.go | 118 ++++++++- 12 files changed, 906 insertions(+), 25 deletions(-) diff --git a/cmd/ls.go b/cmd/ls.go index 81f195d..04d5cd8 100644 --- a/cmd/ls.go +++ b/cmd/ls.go @@ -27,6 +27,7 @@ import ( // ls-specific flags var ( outputFormat string + outputFile string noHeaders bool showTimestamp bool sortBy string @@ -72,9 +73,77 @@ func runListCommand(outputFormat string, args []string) { selectedFields = strings.Split(fields, ",") } + // handle file output + if outputFile != "" { + writeToFile(rt.Connections, outputFile, selectedFields) + return + } + renderList(rt.Connections, outputFormat, selectedFields) } +func writeToFile(connections []collector.Connection, filename string, selectedFields []string) { + file, err := os.Create(filename) + if err != nil { + log.Fatalf("failed to create file: %v", err) + } + defer errutil.Close(file) + + // determine format from extension + format := "csv" + lowerFilename := strings.ToLower(filename) + if strings.HasSuffix(lowerFilename, ".json") { + format = "json" + } else if strings.HasSuffix(lowerFilename, ".tsv") { + format = "tsv" + } + + if len(selectedFields) == 0 { + selectedFields = []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"} + if showTimestamp { + selectedFields = append([]string{"ts"}, selectedFields...) + } + } + + switch format { + case "json": + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + if err := encoder.Encode(connections); err != nil { + log.Fatalf("failed to write JSON: %v", err) + } + case "tsv": + writeDelimited(file, connections, "\t", !noHeaders, selectedFields) + default: + writeDelimited(file, connections, ",", !noHeaders, selectedFields) + } + + fmt.Fprintf(os.Stderr, "exported %d connections to %s\n", len(connections), filename) +} + +func writeDelimited(w io.Writer, connections []collector.Connection, delimiter string, headers bool, selectedFields []string) { + if headers { + headerRow := make([]string, len(selectedFields)) + for i, field := range selectedFields { + headerRow[i] = strings.ToUpper(field) + } + _, _ = fmt.Fprintln(w, strings.Join(headerRow, delimiter)) + } + + for _, conn := range connections { + fieldMap := getFieldMap(conn) + row := make([]string, len(selectedFields)) + for i, field := range selectedFields { + val := fieldMap[field] + if delimiter == "," && (strings.Contains(val, ",") || strings.Contains(val, "\"") || strings.Contains(val, "\n")) { + val = "\"" + strings.ReplaceAll(val, "\"", "\"\"") + "\"" + } + row[i] = val + } + _, _ = fmt.Fprintln(w, strings.Join(row, delimiter)) + } +} + func renderList(connections []collector.Connection, format string, selectedFields []string) { switch format { case "json": @@ -122,6 +191,8 @@ func getFieldMap(c collector.Connection) map[string]string { return map[string]string{ "pid": strconv.Itoa(c.PID), "process": c.Process, + "cmdline": c.Cmdline, + "cwd": c.Cwd, "user": c.User, "uid": strconv.Itoa(c.UID), "proto": c.Proto, @@ -395,6 +466,7 @@ func init() { // ls-specific flags lsCmd.Flags().StringVarP(&outputFormat, "output", "o", cfg.Defaults.OutputFormat, "Output format (table, wide, json, csv)") + lsCmd.Flags().StringVarP(&outputFile, "output-file", "O", "", "Write output to file (format detected from extension: .csv, .tsv, .json)") lsCmd.Flags().BoolVar(&noHeaders, "no-headers", cfg.Defaults.NoHeaders, "Omit headers for table/csv output") lsCmd.Flags().BoolVar(&showTimestamp, "ts", false, "Include timestamp in output") lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)") diff --git a/internal/collector/collector_darwin.go b/internal/collector/collector_darwin.go index 90291ef..7b4cda6 100644 --- a/internal/collector/collector_darwin.go +++ b/internal/collector/collector_darwin.go @@ -37,6 +37,19 @@ static const char* get_username(int uid) { return pw->pw_name; } +// get current working directory for a process +static int get_proc_cwd(int pid, char *path, int pathlen) { + struct proc_vnodepathinfo vpi; + int ret = proc_pidinfo(pid, PROC_PIDVNODEPATHINFO, 0, &vpi, sizeof(vpi)); + if (ret <= 0) { + path[0] = '\0'; + return -1; + } + strncpy(path, vpi.pvi_cdir.vip_path, pathlen - 1); + path[pathlen - 1] = '\0'; + return 0; +} + // socket info extraction - handles the union properly in C typedef struct { int family; @@ -164,6 +177,7 @@ func listAllPids() ([]int, error) { func getConnectionsForPid(pid int) ([]Connection, error) { procName := getProcessName(pid) + cwd := getProcessCwd(pid) uid := int(C.get_proc_uid(C.int(pid))) user := "" if uid >= 0 { @@ -198,7 +212,7 @@ func getConnectionsForPid(pid int) ([]Connection, error) { continue } - conn, ok := getSocketInfo(pid, int(fdInfo.proc_fd), procName, uid, user) + conn, ok := getSocketInfo(pid, int(fdInfo.proc_fd), procName, cwd, uid, user) if ok { connections = append(connections, conn) } @@ -207,7 +221,7 @@ func getConnectionsForPid(pid int) ([]Connection, error) { return connections, nil } -func getSocketInfo(pid, fd int, procName string, uid int, user string) (Connection, bool) { +func getSocketInfo(pid, fd int, procName, cwd string, uid int, user string) (Connection, bool) { var info C.socket_info_t ret := C.get_socket_info(C.int(pid), C.int(fd), &info) @@ -276,6 +290,7 @@ func getSocketInfo(pid, fd int, procName string, uid int, user string) (Connecti Rport: int(info.rport), PID: pid, Process: procName, + Cwd: cwd, UID: uid, User: user, Interface: guessNetworkInterface(laddr), @@ -293,6 +308,15 @@ func getProcessName(pid int) string { return C.GoString(&name[0]) } +func getProcessCwd(pid int) string { + var path [1024]C.char + ret := C.get_proc_cwd(C.int(pid), &path[0], 1024) + if ret != 0 { + return "" + } + return C.GoString(&path[0]) +} + func ipv4ToString(addr uint32) string { ip := make(net.IP, 4) ip[0] = byte(addr) diff --git a/internal/collector/collector_linux.go b/internal/collector/collector_linux.go index 0b9bafd..ba435ec 100644 --- a/internal/collector/collector_linux.go +++ b/internal/collector/collector_linux.go @@ -125,6 +125,8 @@ func GetAllConnections() ([]Connection, error) { type processInfo struct { pid int command string + cmdline string + cwd string uid int user string } @@ -248,34 +250,45 @@ func scanProcessSockets(pid int) []inodeEntry { func getProcessInfo(pid int) (*processInfo, error) { info := &processInfo{pid: pid} + pidStr := strconv.Itoa(pid) - commPath := filepath.Join("/proc", strconv.Itoa(pid), "comm") + commPath := filepath.Join("/proc", pidStr, "comm") commData, err := os.ReadFile(commPath) if err == nil && len(commData) > 0 { info.command = strings.TrimSpace(string(commData)) } - if info.command == "" { - cmdlinePath := filepath.Join("/proc", strconv.Itoa(pid), "cmdline") - cmdlineData, err := os.ReadFile(cmdlinePath) - if err != nil { - return nil, err - } - - if len(cmdlineData) > 0 { - parts := bytes.Split(cmdlineData, []byte{0}) - if len(parts) > 0 && len(parts[0]) > 0 { - fullPath := string(parts[0]) - baseName := filepath.Base(fullPath) - if strings.Contains(baseName, " ") { - baseName = strings.Fields(baseName)[0] - } - info.command = baseName + cmdlinePath := filepath.Join("/proc", pidStr, "cmdline") + cmdlineData, err := os.ReadFile(cmdlinePath) + if err == nil && len(cmdlineData) > 0 { + parts := bytes.Split(cmdlineData, []byte{0}) + var args []string + for _, p := range parts { + if len(p) > 0 { + args = append(args, string(p)) } } + info.cmdline = strings.Join(args, " ") + + if info.command == "" && len(parts) > 0 && len(parts[0]) > 0 { + fullPath := string(parts[0]) + baseName := filepath.Base(fullPath) + if strings.Contains(baseName, " ") { + baseName = strings.Fields(baseName)[0] + } + info.command = baseName + } + } else if info.command == "" { + return nil, err } - statusPath := filepath.Join("/proc", strconv.Itoa(pid), "status") + cwdPath := filepath.Join("/proc", pidStr, "cwd") + cwdLink, err := os.Readlink(cwdPath) + if err == nil { + info.cwd = cwdLink + } + + statusPath := filepath.Join("/proc", pidStr, "status") statusFile, err := os.Open(statusPath) if err != nil { return info, nil @@ -361,6 +374,8 @@ func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*process if procInfo, exists := inodeMap[inode]; exists { conn.PID = procInfo.pid conn.Process = procInfo.command + conn.Cmdline = procInfo.cmdline + conn.Cwd = procInfo.cwd conn.UID = procInfo.uid conn.User = procInfo.user } diff --git a/internal/collector/collector_test.go b/internal/collector/collector_test.go index 207c209..6517945 100644 --- a/internal/collector/collector_test.go +++ b/internal/collector/collector_test.go @@ -114,4 +114,60 @@ func BenchmarkBuildInodeMap(b *testing.B) { for i := 0; i < b.N; i++ { _, _ = buildInodeToProcessMap() } +} + +func TestConnectionHasCmdlineAndCwd(t *testing.T) { + conns, err := GetConnections() + if err != nil { + t.Fatalf("GetConnections() returned an error: %v", err) + } + + if len(conns) == 0 { + t.Skip("no connections to test") + } + + // find a connection with a PID (owned by some process) + var connWithProcess *Connection + for i := range conns { + if conns[i].PID > 0 { + connWithProcess = &conns[i] + break + } + } + + if connWithProcess == nil { + t.Skip("no connections with associated process found") + } + + t.Logf("testing connection: pid=%d process=%s", connWithProcess.PID, connWithProcess.Process) + + // cmdline and cwd should be populated for connections with PIDs + // note: they might be empty if we don't have permission to read them + if connWithProcess.Cmdline != "" { + t.Logf("cmdline: %s", connWithProcess.Cmdline) + } else { + t.Logf("cmdline is empty (might be permission issue)") + } + + if connWithProcess.Cwd != "" { + t.Logf("cwd: %s", connWithProcess.Cwd) + } else { + t.Logf("cwd is empty (might be permission issue)") + } +} + +func TestGetProcessInfoPopulatesCmdlineAndCwd(t *testing.T) { + // test that getProcessInfo correctly populates cmdline and cwd for our own process + info, err := getProcessInfo(1) // init process (usually has cwd of /) + if err != nil { + t.Logf("could not get process info for pid 1: %v", err) + t.Skip("skipping - may not have permission") + } + + t.Logf("pid 1 info: command=%s cmdline=%s cwd=%s", info.command, info.cmdline, info.cwd) + + // at minimum, we should have a command name + if info.command == "" && info.cmdline == "" { + t.Error("expected either command or cmdline to be populated") + } } \ No newline at end of file diff --git a/internal/collector/sort_test.go b/internal/collector/sort_test.go index 43cf35a..eb3942c 100644 --- a/internal/collector/sort_test.go +++ b/internal/collector/sort_test.go @@ -128,3 +128,75 @@ func TestSortByTimestamp(t *testing.T) { } } +func TestSortByRemoteAddr(t *testing.T) { + conns := []Connection{ + {Raddr: "192.168.1.100", Rport: 443}, + {Raddr: "10.0.0.1", Rport: 80}, + {Raddr: "172.16.0.50", Rport: 8080}, + } + + t.Run("sort by raddr ascending", func(t *testing.T) { + c := make([]Connection, len(conns)) + copy(c, conns) + + SortConnections(c, SortOptions{Field: SortByRaddr, Direction: SortAsc}) + + if c[0].Raddr != "10.0.0.1" { + t.Errorf("expected '10.0.0.1' first, got '%s'", c[0].Raddr) + } + if c[1].Raddr != "172.16.0.50" { + t.Errorf("expected '172.16.0.50' second, got '%s'", c[1].Raddr) + } + if c[2].Raddr != "192.168.1.100" { + t.Errorf("expected '192.168.1.100' last, got '%s'", c[2].Raddr) + } + }) + + t.Run("sort by raddr descending", func(t *testing.T) { + c := make([]Connection, len(conns)) + copy(c, conns) + + SortConnections(c, SortOptions{Field: SortByRaddr, Direction: SortDesc}) + + if c[0].Raddr != "192.168.1.100" { + t.Errorf("expected '192.168.1.100' first, got '%s'", c[0].Raddr) + } + }) +} + +func TestSortByRemotePort(t *testing.T) { + conns := []Connection{ + {Raddr: "192.168.1.1", Rport: 443}, + {Raddr: "192.168.1.2", Rport: 80}, + {Raddr: "192.168.1.3", Rport: 8080}, + } + + t.Run("sort by rport ascending", func(t *testing.T) { + c := make([]Connection, len(conns)) + copy(c, conns) + + SortConnections(c, SortOptions{Field: SortByRport, Direction: SortAsc}) + + if c[0].Rport != 80 { + t.Errorf("expected port 80 first, got %d", c[0].Rport) + } + if c[1].Rport != 443 { + t.Errorf("expected port 443 second, got %d", c[1].Rport) + } + if c[2].Rport != 8080 { + t.Errorf("expected port 8080 last, got %d", c[2].Rport) + } + }) + + t.Run("sort by rport descending", func(t *testing.T) { + c := make([]Connection, len(conns)) + copy(c, conns) + + SortConnections(c, SortOptions{Field: SortByRport, Direction: SortDesc}) + + if c[0].Rport != 8080 { + t.Errorf("expected port 8080 first, got %d", c[0].Rport) + } + }) +} + diff --git a/internal/collector/types.go b/internal/collector/types.go index 3cfeeb7..a6a5004 100644 --- a/internal/collector/types.go +++ b/internal/collector/types.go @@ -6,6 +6,8 @@ type Connection struct { TS time.Time `json:"ts"` PID int `json:"pid"` Process string `json:"process"` + Cmdline string `json:"cmdline,omitempty"` + Cwd string `json:"cwd,omitempty"` User string `json:"user"` UID int `json:"uid"` Proto string `json:"proto"` diff --git a/internal/tui/helpers.go b/internal/tui/helpers.go index c49daa6..807a51a 100644 --- a/internal/tui/helpers.go +++ b/internal/tui/helpers.go @@ -38,6 +38,10 @@ func sortFieldLabel(f collector.SortField) string { return "state" case collector.SortByProto: return "proto" + case collector.SortByRaddr: + return "raddr" + case collector.SortByRport: + return "rport" default: return "port" } diff --git a/internal/tui/keys.go b/internal/tui/keys.go index cdafc65..93643c5 100644 --- a/internal/tui/keys.go +++ b/internal/tui/keys.go @@ -2,10 +2,12 @@ package tui import ( "fmt" - "github.com/karol-broda/snitch/internal/collector" + "strings" "time" tea "github.com/charmbracelet/bubbletea" + + "github.com/karol-broda/snitch/internal/collector" ) func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { @@ -14,6 +16,11 @@ func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m.handleSearchKey(msg) } + // export modal captures all input + if m.showExportModal { + return m.handleExportKey(msg) + } + // kill confirmation dialog if m.showKillConfirm { return m.handleKillConfirmKey(msg) @@ -52,6 +59,82 @@ func (m model) handleSearchKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } +func (m model) handleExportKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "esc": + m.showExportModal = false + m.exportFilename = "" + m.exportFormat = "" + m.exportError = "" + + case "tab": + // toggle format + if m.exportFormat == "tsv" { + m.exportFormat = "csv" + } else { + m.exportFormat = "tsv" + } + m.exportError = "" + + case "enter": + // build final filename with extension + filename := m.exportFilename + if filename == "" { + filename = "connections" + } + + ext := ".csv" + if m.exportFormat == "tsv" { + ext = ".tsv" + } + + // only add extension if not already present + if !strings.HasSuffix(strings.ToLower(filename), ".csv") && + !strings.HasSuffix(strings.ToLower(filename), ".tsv") { + filename = filename + ext + } + m.exportFilename = filename + + err := m.exportConnections() + if err != nil { + m.exportError = err.Error() + return m, nil + } + + visible := m.visibleConnections() + m.statusMessage = fmt.Sprintf("%s exported %d connections to %s", SymbolSuccess, len(visible), filename) + m.statusExpiry = time.Now().Add(3 * time.Second) + m.showExportModal = false + m.exportFilename = "" + m.exportFormat = "" + m.exportError = "" + return m, clearStatusAfter(3 * time.Second) + + case "backspace": + if len(m.exportFilename) > 0 { + m.exportFilename = m.exportFilename[:len(m.exportFilename)-1] + } + m.exportError = "" + + default: + // only accept valid filename characters + char := msg.String() + if len(char) == 1 && isValidFilenameChar(char[0]) { + m.exportFilename += char + m.exportError = "" + } + } + return m, nil +} + +func isValidFilenameChar(c byte) bool { + // allow alphanumeric, dash, underscore, dot + return (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + c == '-' || c == '_' || c == '.' +} + func (m model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "esc", "enter", "q": @@ -157,6 +240,13 @@ func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.searchActive = true m.searchQuery = "" + // export + case "x": + m.showExportModal = true + m.exportFilename = "" + m.exportFormat = "csv" + m.exportError = "" + // actions case "enter", " ": visible := m.visibleConnections() @@ -276,6 +366,8 @@ func (m *model) cycleSort() { collector.SortByPID, collector.SortByState, collector.SortByProto, + collector.SortByRaddr, + collector.SortByRport, } for i, f := range fields { diff --git a/internal/tui/model.go b/internal/tui/model.go index b95aeb3..1af7e88 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -2,6 +2,9 @@ package tui import ( "fmt" + "os" + "strconv" + "strings" "time" tea "github.com/charmbracelet/bubbletea" @@ -54,6 +57,12 @@ type model struct { statusMessage string statusExpiry time.Time + // export modal + showExportModal bool + exportFilename string + exportFormat string // "csv" or "tsv" + exportError string + // state persistence rememberState bool } @@ -214,6 +223,11 @@ func (m model) View() string { return m.overlayModal(main, m.renderKillModal()) } + // overlay export modal on top of main view + if m.showExportModal { + return m.overlayModal(main, m.renderExportModal()) + } + return main } @@ -289,12 +303,19 @@ func (m model) matchesFilters(c collector.Connection) bool { } func (m model) matchesSearch(c collector.Connection) bool { + lportStr := strconv.Itoa(c.Lport) + rportStr := strconv.Itoa(c.Rport) + pidStr := strconv.Itoa(c.PID) + return containsIgnoreCase(c.Process, m.searchQuery) || containsIgnoreCase(c.Laddr, m.searchQuery) || containsIgnoreCase(c.Raddr, m.searchQuery) || containsIgnoreCase(c.User, m.searchQuery) || containsIgnoreCase(c.Proto, m.searchQuery) || - containsIgnoreCase(c.State, m.searchQuery) + containsIgnoreCase(c.State, m.searchQuery) || + containsIgnoreCase(lportStr, m.searchQuery) || + containsIgnoreCase(rportStr, m.searchQuery) || + containsIgnoreCase(pidStr, m.searchQuery) } func (m model) isWatched(pid int) bool { @@ -340,3 +361,62 @@ func (m model) saveState() { state.SaveAsync(m.currentState()) } } + +// exportConnections writes visible connections to a file in csv or tsv format +func (m model) exportConnections() error { + visible := m.visibleConnections() + + if len(visible) == 0 { + return fmt.Errorf("no connections to export") + } + + file, err := os.Create(m.exportFilename) + if err != nil { + return err + } + defer func() { _ = file.Close() }() + + // determine delimiter from format selection or filename + delimiter := "," + if m.exportFormat == "tsv" || strings.HasSuffix(strings.ToLower(m.exportFilename), ".tsv") { + delimiter = "\t" + } + + header := []string{"PID", "PROCESS", "USER", "PROTO", "STATE", "LADDR", "LPORT", "RADDR", "RPORT"} + _, err = file.WriteString(strings.Join(header, delimiter) + "\n") + if err != nil { + return err + } + + for _, c := range visible { + // escape fields that might contain delimiter + process := escapeField(c.Process, delimiter) + user := escapeField(c.User, delimiter) + + row := []string{ + strconv.Itoa(c.PID), + process, + user, + c.Proto, + c.State, + c.Laddr, + strconv.Itoa(c.Lport), + c.Raddr, + strconv.Itoa(c.Rport), + } + _, err = file.WriteString(strings.Join(row, delimiter) + "\n") + if err != nil { + return err + } + } + + return nil +} + +// escapeField quotes a field if it contains the delimiter or quotes +func escapeField(s, delimiter string) string { + if strings.Contains(s, delimiter) || strings.Contains(s, "\"") || strings.Contains(s, "\n") { + return "\"" + strings.ReplaceAll(s, "\"", "\"\"") + "\"" + } + return s +} diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index 0aef3a8..3fc3567 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -1,12 +1,15 @@ package tui import ( - "github.com/karol-broda/snitch/internal/collector" + "os" + "path/filepath" + "strings" "testing" "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/x/exp/teatest" + "github.com/karol-broda/snitch/internal/collector" ) func TestTUI_InitialState(t *testing.T) { @@ -430,3 +433,346 @@ func TestTUI_FormatRemoteHelper(t *testing.T) { } } +func TestTUI_MatchesSearchPort(t *testing.T) { + m := New(Options{Theme: "dark"}) + + tests := []struct { + name string + searchQuery string + conn collector.Connection + expected bool + }{ + { + name: "matches local port", + searchQuery: "3000", + conn: collector.Connection{Lport: 3000}, + expected: true, + }, + { + name: "matches remote port", + searchQuery: "443", + conn: collector.Connection{Rport: 443}, + expected: true, + }, + { + name: "matches pid", + searchQuery: "1234", + conn: collector.Connection{PID: 1234}, + expected: true, + }, + { + name: "partial port match", + searchQuery: "80", + conn: collector.Connection{Lport: 8080}, + expected: true, + }, + { + name: "no match", + searchQuery: "9999", + conn: collector.Connection{Lport: 80, Rport: 443, PID: 1234}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + m.searchQuery = tc.searchQuery + result := m.matchesSearch(tc.conn) + if result != tc.expected { + t.Errorf("matchesSearch() = %v, want %v", result, tc.expected) + } + }) + } +} + +func TestTUI_SortCycleIncludesRemote(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + + // start at default (Lport) + if m.sortField != collector.SortByLport { + t.Fatalf("expected initial sort field to be lport, got %v", m.sortField) + } + + // cycle through all fields and verify raddr and rport are included + foundRaddr := false + foundRport := false + seenFields := make(map[collector.SortField]bool) + + for i := 0; i < 10; i++ { + m.cycleSort() + seenFields[m.sortField] = true + + if m.sortField == collector.SortByRaddr { + foundRaddr = true + } + if m.sortField == collector.SortByRport { + foundRport = true + } + + if foundRaddr && foundRport { + break + } + } + + if !foundRaddr { + t.Error("expected sort cycle to include SortByRaddr") + } + if !foundRport { + t.Error("expected sort cycle to include SortByRport") + } +} + +func TestTUI_ExportModal(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + m.width = 120 + m.height = 40 + + // initially export modal should not be shown + if m.showExportModal { + t.Fatal("expected showExportModal to be false initially") + } + + // press 'x' to open export modal + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}}) + m = newModel.(model) + + if !m.showExportModal { + t.Error("expected showExportModal to be true after pressing 'x'") + } + + // type filename + for _, c := range "test.csv" { + newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{c}}) + m = newModel.(model) + } + + if m.exportFilename != "test.csv" { + t.Errorf("expected exportFilename to be 'test.csv', got '%s'", m.exportFilename) + } + + // escape should close modal + newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + m = newModel.(model) + + if m.showExportModal { + t.Error("expected showExportModal to be false after escape") + } + if m.exportFilename != "" { + t.Error("expected exportFilename to be cleared after escape") + } +} + +func TestTUI_ExportModalDefaultFilename(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + m.width = 120 + m.height = 40 + + // add test data + m.connections = []collector.Connection{ + {PID: 1234, Process: "nginx", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80}, + } + + // open export modal + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}}) + m = newModel.(model) + + // render export modal should show default filename hint + view := m.View() + if view == "" { + t.Error("expected non-empty view with export modal") + } +} + +func TestTUI_ExportModalBackspace(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + m.width = 120 + m.height = 40 + + // open export modal + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}}) + m = newModel.(model) + + // type filename + for _, c := range "test.csv" { + newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{c}}) + m = newModel.(model) + } + + // backspace should remove last character + newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyBackspace}) + m = newModel.(model) + + if m.exportFilename != "test.cs" { + t.Errorf("expected 'test.cs' after backspace, got '%s'", m.exportFilename) + } +} + +func TestTUI_ExportConnectionsCSV(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + + m.connections = []collector.Connection{ + {PID: 1234, Process: "nginx", User: "www-data", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80, Raddr: "*", Rport: 0}, + {PID: 5678, Process: "node", User: "node", Proto: "tcp", State: "ESTABLISHED", Laddr: "192.168.1.1", Lport: 3000, Raddr: "10.0.0.1", Rport: 443}, + } + + tmpDir := t.TempDir() + csvPath := filepath.Join(tmpDir, "test_export.csv") + m.exportFilename = csvPath + + err := m.exportConnections() + if err != nil { + t.Fatalf("exportConnections() failed: %v", err) + } + + content, err := os.ReadFile(csvPath) + if err != nil { + t.Fatalf("failed to read exported file: %v", err) + } + + lines := strings.Split(strings.TrimSpace(string(content)), "\n") + if len(lines) != 3 { + t.Errorf("expected 3 lines (header + 2 data), got %d", len(lines)) + } + + if !strings.Contains(lines[0], "PID") || !strings.Contains(lines[0], "PROCESS") { + t.Error("header line should contain PID and PROCESS") + } + + if !strings.Contains(lines[1], "nginx") || !strings.Contains(lines[1], "1234") { + t.Error("first data line should contain nginx and 1234") + } + + if !strings.Contains(lines[2], "node") || !strings.Contains(lines[2], "5678") { + t.Error("second data line should contain node and 5678") + } +} + +func TestTUI_ExportConnectionsTSV(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + + m.connections = []collector.Connection{ + {PID: 1234, Process: "nginx", User: "www-data", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80, Raddr: "*", Rport: 0}, + } + + tmpDir := t.TempDir() + tsvPath := filepath.Join(tmpDir, "test_export.tsv") + m.exportFilename = tsvPath + + err := m.exportConnections() + if err != nil { + t.Fatalf("exportConnections() failed: %v", err) + } + + content, err := os.ReadFile(tsvPath) + if err != nil { + t.Fatalf("failed to read exported file: %v", err) + } + + lines := strings.Split(strings.TrimSpace(string(content)), "\n") + + // TSV should use tabs + if !strings.Contains(lines[0], "\t") { + t.Error("TSV file should use tabs as delimiters") + } + + // CSV delimiter should not be present between fields + fields := strings.Split(lines[1], "\t") + if len(fields) < 9 { + t.Errorf("expected at least 9 tab-separated fields, got %d", len(fields)) + } +} + +func TestTUI_ExportWithFilters(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + m.showTCP = true + m.showUDP = false + + m.connections = []collector.Connection{ + {PID: 1, Process: "tcp_proc", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80}, + {PID: 2, Process: "udp_proc", Proto: "udp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 53}, + } + + tmpDir := t.TempDir() + csvPath := filepath.Join(tmpDir, "filtered_export.csv") + m.exportFilename = csvPath + + err := m.exportConnections() + if err != nil { + t.Fatalf("exportConnections() failed: %v", err) + } + + content, err := os.ReadFile(csvPath) + if err != nil { + t.Fatalf("failed to read exported file: %v", err) + } + + lines := strings.Split(strings.TrimSpace(string(content)), "\n") + + // should only have header + 1 TCP connection (UDP filtered out) + if len(lines) != 2 { + t.Errorf("expected 2 lines (header + 1 TCP), got %d", len(lines)) + } + + if strings.Contains(string(content), "udp_proc") { + t.Error("UDP connection should not be exported when UDP filter is off") + } +} + +func TestTUI_ExportFormatToggle(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + m.width = 120 + m.height = 40 + + // open export modal + newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}}) + m = newModel.(model) + + // default format should be csv + if m.exportFormat != "csv" { + t.Errorf("expected default format 'csv', got '%s'", m.exportFormat) + } + + // tab should toggle to tsv + newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab}) + m = newModel.(model) + + if m.exportFormat != "tsv" { + t.Errorf("expected format 'tsv' after tab, got '%s'", m.exportFormat) + } + + // tab again should toggle back to csv + newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab}) + m = newModel.(model) + + if m.exportFormat != "csv" { + t.Errorf("expected format 'csv' after second tab, got '%s'", m.exportFormat) + } +} + +func TestTUI_ExportModalRenderWithStats(t *testing.T) { + m := New(Options{Theme: "dark", Interval: time.Hour}) + m.width = 120 + m.height = 40 + + m.connections = []collector.Connection{ + {PID: 1, Process: "nginx", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80}, + {PID: 2, Process: "postgres", Proto: "tcp", State: "LISTEN", Laddr: "127.0.0.1", Lport: 5432}, + {PID: 3, Process: "node", Proto: "tcp", State: "ESTABLISHED", Laddr: "192.168.1.1", Lport: 3000}, + } + + m.showExportModal = true + m.exportFormat = "csv" + + view := m.View() + + // modal should contain summary info + if !strings.Contains(view, "3") { + t.Error("modal should show connection count") + } + + // modal should show format options + if !strings.Contains(view, "CSV") || !strings.Contains(view, "TSV") { + t.Error("modal should show format options") + } +} + diff --git a/internal/tui/symbols.go b/internal/tui/symbols.go index c1a00fd..961c17f 100644 --- a/internal/tui/symbols.go +++ b/internal/tui/symbols.go @@ -33,6 +33,8 @@ const ( BoxCross = string('\u253C') // light vertical and horizontal // misc - SymbolDash = string('\u2013') // en dash + SymbolDash = string('\u2013') // en dash + SymbolExport = string('\u21E5') // rightwards arrow to bar + SymbolPrompt = string('\u276F') // heavy right-pointing angle quotation mark ornament ) diff --git a/internal/tui/view.go b/internal/tui/view.go index 041c0e7..73db0d3 100644 --- a/internal/tui/view.go +++ b/internal/tui/view.go @@ -203,7 +203,7 @@ func (m model) renderStatusLine() string { return " " + m.theme.Styles.Warning.Render(m.statusMessage) } - left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state n/N dns w watch K kill s sort / search ? help q quit") + left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state n/N dns w watch K kill s sort / search x export ? help q quit") // show watched count if any if m.watchedCount() > 0 { @@ -271,6 +271,7 @@ func (m model) renderHelp() string { other ───── / search + x export to csv/tsv (enter filename) r refresh now q quit @@ -301,6 +302,8 @@ func (m model) renderDetail() string { value string }{ {"process", c.Process}, + {"cmdline", c.Cmdline}, + {"cwd", c.Cwd}, {"pid", fmt.Sprintf("%d", c.PID)}, {"user", c.User}, {"protocol", c.Proto}, @@ -368,6 +371,119 @@ func (m model) renderKillModal() string { return strings.Join(lines, "\n") } +func (m model) renderExportModal() string { + visible := m.visibleConnections() + + // count protocols and states for preview + tcpCount, udpCount := 0, 0 + listenCount, estabCount, otherCount := 0, 0, 0 + for _, c := range visible { + if c.Proto == "tcp" || c.Proto == "tcp6" { + tcpCount++ + } else { + udpCount++ + } + switch c.State { + case "LISTEN": + listenCount++ + case "ESTABLISHED": + estabCount++ + default: + otherCount++ + } + } + + var lines []string + + // header + lines = append(lines, "") + headerText := " " + SymbolExport + " EXPORT CONNECTIONS " + lines = append(lines, m.theme.Styles.Header.Render(headerText)) + lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 36))) + lines = append(lines, "") + + // stats preview section + lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" summary")) + lines = append(lines, fmt.Sprintf(" total: %s", + m.theme.Styles.Success.Render(fmt.Sprintf("%d connections", len(visible))))) + + protoSummary := fmt.Sprintf(" proto: %s tcp %s udp", + m.theme.Styles.GetProtoStyle("tcp").Render(fmt.Sprintf("%d", tcpCount)), + m.theme.Styles.GetProtoStyle("udp").Render(fmt.Sprintf("%d", udpCount))) + lines = append(lines, protoSummary) + + stateSummary := fmt.Sprintf(" state: %s listen %s estab %s other", + m.theme.Styles.GetStateStyle("LISTEN").Render(fmt.Sprintf("%d", listenCount)), + m.theme.Styles.GetStateStyle("ESTABLISHED").Render(fmt.Sprintf("%d", estabCount)), + m.theme.Styles.Normal.Render(fmt.Sprintf("%d", otherCount))) + lines = append(lines, stateSummary) + lines = append(lines, "") + + // format selection + lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" format")) + + csvStyle := m.theme.Styles.Normal + tsvStyle := m.theme.Styles.Normal + csvIndicator := " " + tsvIndicator := " " + + if m.exportFormat == "tsv" { + tsvStyle = m.theme.Styles.Success + tsvIndicator = m.theme.Styles.Success.Render(SymbolSelected + " ") + } else { + csvStyle = m.theme.Styles.Success + csvIndicator = m.theme.Styles.Success.Render(SymbolSelected + " ") + } + + formatLine := fmt.Sprintf(" %s%s %s%s", + csvIndicator, csvStyle.Render("CSV (comma)"), + tsvIndicator, tsvStyle.Render("TSV (tab)")) + lines = append(lines, formatLine) + lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 8)+" press "+m.theme.Styles.Warning.Render("tab")+" to toggle")) + lines = append(lines, "") + + // filename input + lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" filename")) + + ext := ".csv" + if m.exportFormat == "tsv" { + ext = ".tsv" + } + + filenameDisplay := m.exportFilename + if filenameDisplay == "" { + filenameDisplay = "connections" + } + + inputBox := fmt.Sprintf(" %s %s%s", + m.theme.Styles.Success.Render(SymbolPrompt), + m.theme.Styles.Warning.Render(filenameDisplay), + m.theme.Styles.Success.Render(ext+"▌")) + lines = append(lines, inputBox) + lines = append(lines, "") + + // error display + if m.exportError != "" { + lines = append(lines, m.theme.Styles.Error.Render(fmt.Sprintf(" %s %s", SymbolWarning, m.exportError))) + lines = append(lines, "") + } + + // preview of fields + lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 36))) + fieldsPreview := " fields: PID, PROCESS, USER, PROTO, STATE, LADDR, LPORT, RADDR, RPORT" + lines = append(lines, m.theme.Styles.Normal.Render(truncate(fieldsPreview, 40))) + lines = append(lines, "") + + // action buttons + lines = append(lines, fmt.Sprintf(" %s export %s toggle format %s cancel", + m.theme.Styles.Success.Render("[enter]"), + m.theme.Styles.Warning.Render("[tab]"), + m.theme.Styles.Error.Render("[esc]"))) + lines = append(lines, "") + + return strings.Join(lines, "\n") +} + func (m model) overlayModal(background, modal string) string { bgLines := strings.Split(background, "\n") modalLines := strings.Split(modal, "\n")