9 Commits

53 changed files with 4507 additions and 450 deletions

View File

@@ -32,9 +32,9 @@ jobs:
go-version: "1.25.0" go-version: "1.25.0"
- name: lint - name: lint
uses: golangci/golangci-lint-action@v6 uses: golangci/golangci-lint-action@v8
with: with:
version: latest version: v2.5.0
nix-build: nix-build:
strategy: strategy:
@@ -46,7 +46,10 @@ jobs:
- uses: DeterminateSystems/nix-installer-action@v17 - uses: DeterminateSystems/nix-installer-action@v17
- uses: DeterminateSystems/magic-nix-cache-action@v9 - uses: nix-community/cache-nix-action@v6
with:
primary-key: nix-${{ runner.os }}-${{ hashFiles('flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-
- name: nix flake check - name: nix flake check
run: nix flake check run: nix flake check

View File

@@ -82,6 +82,7 @@ aurs:
package: |- package: |-
install -Dm755 "./snitch" "${pkgdir}/usr/bin/snitch" install -Dm755 "./snitch" "${pkgdir}/usr/bin/snitch"
install -Dm644 "./LICENSE" "${pkgdir}/usr/share/licenses/snitch/LICENSE" install -Dm644 "./LICENSE" "${pkgdir}/usr/share/licenses/snitch/LICENSE"
install -Dm644 "./README.md" "${pkgdir}/usr/share/doc/snitch/README.md"
commit_msg_template: "Update to {{ .Tag }}" commit_msg_template: "Update to {{ .Tag }}"
skip_upload: auto skip_upload: auto

100
README.md
View File

@@ -6,13 +6,29 @@ a friendlier `ss` / `netstat` for humans. inspect network connections with a cle
## install ## install
### homebrew
```bash
brew install snitch
```
> thanks to [@bevanjkay](https://github.com/bevanjkay) for adding snitch to homebrew-core
### go ### go
```bash ```bash
go install github.com/karol-broda/snitch@latest go install github.com/karol-broda/snitch@latest
``` ```
### nixos / nix ### nixpkgs
```bash
nix-env -iA nixpkgs.snitch
```
> thanks to [@DieracDelta](https://github.com/DieracDelta) for adding snitch to nixpkgs
### nixos / nix (flake)
```bash ```bash
# try it # try it
@@ -28,6 +44,45 @@ nix profile install github:karol-broda/snitch
# then use: inputs.snitch.packages.${system}.default # then use: inputs.snitch.packages.${system}.default
``` ```
### home-manager (flake)
add snitch to your flake inputs and import the home-manager module:
```nix
{
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
home-manager.url = "github:nix-community/home-manager";
snitch.url = "github:karol-broda/snitch";
};
outputs = { nixpkgs, home-manager, snitch, ... }: {
homeConfigurations."user" = home-manager.lib.homeManagerConfiguration {
pkgs = nixpkgs.legacyPackages.x86_64-linux;
modules = [
snitch.homeManagerModules.default
{
programs.snitch = {
enable = true;
# optional: use the flake's package instead of nixpkgs
# package = snitch.packages.x86_64-linux.default;
settings = {
defaults = {
theme = "catppuccin-mocha";
interval = "2s";
resolve = true;
};
};
};
}
];
};
};
}
```
available themes: `ansi`, `catppuccin-mocha`, `catppuccin-macchiato`, `catppuccin-frappe`, `catppuccin-latte`, `gruvbox-dark`, `gruvbox-light`, `dracula`, `nord`, `tokyo-night`, `tokyo-night-storm`, `tokyo-night-light`, `solarized-dark`, `solarized-light`, `one-dark`, `mono`
### arch linux (aur) ### arch linux (aur)
```bash ```bash
@@ -167,9 +222,20 @@ shortcut flags work on all commands:
-e, --established established connections -e, --established established connections
-4, --ipv4 ipv4 only -4, --ipv4 ipv4 only
-6, --ipv6 ipv6 only -6, --ipv6 ipv6 only
-n, --numeric no dns resolution
``` ```
## resolution
dns and service name resolution options:
```
--resolve-addrs resolve ip addresses to hostnames (default: true)
--resolve-ports resolve port numbers to service names
--no-cache disable dns caching (force fresh lookups)
```
dns lookups are performed in parallel and cached for performance. use `--no-cache` to bypass the cache for debugging or when addresses change frequently.
for more specific filtering, use `key=value` syntax with `ls`: for more specific filtering, use `key=value` syntax with `ls`:
```bash ```bash
@@ -208,8 +274,34 @@ optional config file at `~/.config/snitch/snitch.toml`:
```toml ```toml
[defaults] [defaults]
numeric = false numeric = false # disable name resolution
theme = "auto" dns_cache = true # cache dns lookups (set to false to disable)
theme = "auto" # color theme: auto, dark, light, mono
[tui]
remember_state = false # remember view options between sessions
```
### remembering view options
when `remember_state = true`, the tui will save and restore:
- filter toggles (tcp/udp, listen/established/other)
- sort field and direction
- address and port resolution settings
state is saved to `$XDG_STATE_HOME/snitch/tui.json` (defaults to `~/.local/state/snitch/tui.json`).
cli flags always take priority over saved state.
### environment variables
```bash
SNITCH_THEME=dark # set default theme
SNITCH_RESOLVE=0 # disable dns resolution
SNITCH_DNS_CACHE=0 # disable dns caching
SNITCH_NO_COLOR=1 # disable color output
SNITCH_CONFIG=/path/to # custom config file path
``` ```
## requirements ## requirements

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/karol-broda/snitch/internal/errutil"
"github.com/karol-broda/snitch/internal/testutil" "github.com/karol-broda/snitch/internal/testutil"
) )
@@ -407,16 +408,16 @@ func TestEnvironmentVariables(t *testing.T) {
oldEnvVars := make(map[string]string) oldEnvVars := make(map[string]string)
for key, value := range tt.envVars { for key, value := range tt.envVars {
oldEnvVars[key] = os.Getenv(key) oldEnvVars[key] = os.Getenv(key)
os.Setenv(key, value) errutil.Setenv(key, value)
} }
// Clean up environment variables // Clean up environment variables
defer func() { defer func() {
for key, oldValue := range oldEnvVars { for key, oldValue := range oldEnvVars {
if oldValue == "" { if oldValue == "" {
os.Unsetenv(key) errutil.Unsetenv(key)
} else { } else {
os.Setenv(key, oldValue) errutil.Setenv(key, oldValue)
} }
} }
}() }()

View File

@@ -8,16 +8,18 @@ import (
"log" "log"
"os" "os"
"os/exec" "os/exec"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/color"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/resolver"
"strconv" "strconv"
"strings" "strings"
"text/tabwriter" "text/tabwriter"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/color"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/errutil"
"github.com/karol-broda/snitch/internal/resolver"
"github.com/tidwall/pretty" "github.com/tidwall/pretty"
"golang.org/x/term" "golang.org/x/term"
) )
@@ -25,13 +27,12 @@ import (
// ls-specific flags // ls-specific flags
var ( var (
outputFormat string outputFormat string
outputFile string
noHeaders bool noHeaders bool
showTimestamp bool showTimestamp bool
sortBy string sortBy string
fields string fields string
colorMode string colorMode string
resolveAddrs bool
resolvePorts bool
plainOutput bool plainOutput bool
) )
@@ -72,9 +73,77 @@ func runListCommand(outputFormat string, args []string) {
selectedFields = strings.Split(fields, ",") selectedFields = strings.Split(fields, ",")
} }
// handle file output
if outputFile != "" {
writeToFile(rt.Connections, outputFile, selectedFields)
return
}
renderList(rt.Connections, outputFormat, selectedFields) renderList(rt.Connections, outputFormat, selectedFields)
} }
func writeToFile(connections []collector.Connection, filename string, selectedFields []string) {
file, err := os.Create(filename)
if err != nil {
log.Fatalf("failed to create file: %v", err)
}
defer errutil.Close(file)
// determine format from extension
format := "csv"
lowerFilename := strings.ToLower(filename)
if strings.HasSuffix(lowerFilename, ".json") {
format = "json"
} else if strings.HasSuffix(lowerFilename, ".tsv") {
format = "tsv"
}
if len(selectedFields) == 0 {
selectedFields = []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"}
if showTimestamp {
selectedFields = append([]string{"ts"}, selectedFields...)
}
}
switch format {
case "json":
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(connections); err != nil {
log.Fatalf("failed to write JSON: %v", err)
}
case "tsv":
writeDelimited(file, connections, "\t", !noHeaders, selectedFields)
default:
writeDelimited(file, connections, ",", !noHeaders, selectedFields)
}
fmt.Fprintf(os.Stderr, "exported %d connections to %s\n", len(connections), filename)
}
func writeDelimited(w io.Writer, connections []collector.Connection, delimiter string, headers bool, selectedFields []string) {
if headers {
headerRow := make([]string, len(selectedFields))
for i, field := range selectedFields {
headerRow[i] = strings.ToUpper(field)
}
_, _ = fmt.Fprintln(w, strings.Join(headerRow, delimiter))
}
for _, conn := range connections {
fieldMap := getFieldMap(conn)
row := make([]string, len(selectedFields))
for i, field := range selectedFields {
val := fieldMap[field]
if delimiter == "," && (strings.Contains(val, ",") || strings.Contains(val, "\"") || strings.Contains(val, "\n")) {
val = "\"" + strings.ReplaceAll(val, "\"", "\"\"") + "\""
}
row[i] = val
}
_, _ = fmt.Fprintln(w, strings.Join(row, delimiter))
}
}
func renderList(connections []collector.Connection, format string, selectedFields []string) { func renderList(connections []collector.Connection, format string, selectedFields []string) {
switch format { switch format {
case "json": case "json":
@@ -122,6 +191,8 @@ func getFieldMap(c collector.Connection) map[string]string {
return map[string]string{ return map[string]string{
"pid": strconv.Itoa(c.PID), "pid": strconv.Itoa(c.PID),
"process": c.Process, "process": c.Process,
"cmdline": c.Cmdline,
"cwd": c.Cwd,
"user": c.User, "user": c.User,
"uid": strconv.Itoa(c.UID), "uid": strconv.Itoa(c.UID),
"proto": c.Proto, "proto": c.Proto,
@@ -187,7 +258,7 @@ func printCSV(conns []collector.Connection, headers bool, timestamp bool, select
func printPlainTable(conns []collector.Connection, headers bool, timestamp bool, selectedFields []string) { func printPlainTable(conns []collector.Connection, headers bool, timestamp bool, selectedFields []string) {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
defer w.Flush() defer errutil.Flush(w)
if len(selectedFields) == 0 { if len(selectedFields) == 0 {
selectedFields = []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"} selectedFields = []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"}
@@ -201,7 +272,7 @@ func printPlainTable(conns []collector.Connection, headers bool, timestamp bool,
for _, field := range selectedFields { for _, field := range selectedFields {
headerRow = append(headerRow, strings.ToUpper(field)) headerRow = append(headerRow, strings.ToUpper(field))
} }
fmt.Fprintln(w, strings.Join(headerRow, "\t")) errutil.Ignore(fmt.Fprintln(w, strings.Join(headerRow, "\t")))
} }
for _, conn := range conns { for _, conn := range conns {
@@ -210,7 +281,7 @@ func printPlainTable(conns []collector.Connection, headers bool, timestamp bool,
for _, field := range selectedFields { for _, field := range selectedFields {
row = append(row, fieldMap[field]) row = append(row, fieldMap[field])
} }
fmt.Fprintln(w, strings.Join(row, "\t")) errutil.Ignore(fmt.Fprintln(w, strings.Join(row, "\t")))
} }
} }
@@ -395,15 +466,15 @@ func init() {
// ls-specific flags // ls-specific flags
lsCmd.Flags().StringVarP(&outputFormat, "output", "o", cfg.Defaults.OutputFormat, "Output format (table, wide, json, csv)") lsCmd.Flags().StringVarP(&outputFormat, "output", "o", cfg.Defaults.OutputFormat, "Output format (table, wide, json, csv)")
lsCmd.Flags().StringVarP(&outputFile, "output-file", "O", "", "Write output to file (format detected from extension: .csv, .tsv, .json)")
lsCmd.Flags().BoolVar(&noHeaders, "no-headers", cfg.Defaults.NoHeaders, "Omit headers for table/csv output") lsCmd.Flags().BoolVar(&noHeaders, "no-headers", cfg.Defaults.NoHeaders, "Omit headers for table/csv output")
lsCmd.Flags().BoolVar(&showTimestamp, "ts", false, "Include timestamp in output") lsCmd.Flags().BoolVar(&showTimestamp, "ts", false, "Include timestamp in output")
lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)") lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)")
lsCmd.Flags().StringVarP(&fields, "fields", "f", strings.Join(cfg.Defaults.Fields, ","), "Comma-separated list of fields to show") lsCmd.Flags().StringVarP(&fields, "fields", "f", strings.Join(cfg.Defaults.Fields, ","), "Comma-separated list of fields to show")
lsCmd.Flags().StringVar(&colorMode, "color", cfg.Defaults.Color, "Color mode (auto, always, never)") lsCmd.Flags().StringVar(&colorMode, "color", cfg.Defaults.Color, "Color mode (auto, always, never)")
lsCmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
lsCmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
lsCmd.Flags().BoolVarP(&plainOutput, "plain", "p", false, "Plain output (parsable, no styling)") lsCmd.Flags().BoolVarP(&plainOutput, "plain", "p", false, "Plain output (parsable, no styling)")
// shared filter flags // shared flags
addFilterFlags(lsCmd) addFilterFlags(lsCmd)
addResolutionFlags(lsCmd)
} }

View File

@@ -3,11 +3,12 @@ package cmd
import ( import (
"fmt" "fmt"
"os" "os"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/config"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( var (
cfgFile string cfgFile string
) )
@@ -42,11 +43,10 @@ func init() {
// add top's flags to root so `snitch -l` works (defaults to top command) // add top's flags to root so `snitch -l` works (defaults to top command)
cfg := config.Get() cfg := config.Get()
rootCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (dark, light, mono, auto)") rootCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
rootCmd.Flags().DurationVarP(&topInterval, "interval", "i", 0, "Refresh interval (default 1s)") rootCmd.Flags().DurationVarP(&topInterval, "interval", "i", 0, "Refresh interval (default 1s)")
rootCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
rootCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
// shared filter flags for root command // shared flags for root command
addFilterFlags(rootCmd) addFilterFlags(rootCmd)
addResolutionFlags(rootCmd)
} }

View File

@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"github.com/karol-broda/snitch/internal/collector" "github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/color" "github.com/karol-broda/snitch/internal/color"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/resolver"
"strconv" "strconv"
"strings" "strings"
@@ -11,7 +13,7 @@ import (
) )
// Runtime holds the shared state for all commands. // Runtime holds the shared state for all commands.
// it handles common filter logic, fetching, and filtering connections. // it handles common filter logic, fetching, filtering, and resolution.
type Runtime struct { type Runtime struct {
// filter options built from flags and args // filter options built from flags and args
Filters collector.FilterOptions Filters collector.FilterOptions
@@ -23,6 +25,7 @@ type Runtime struct {
ColorMode string ColorMode string
ResolveAddrs bool ResolveAddrs bool
ResolvePorts bool ResolvePorts bool
NoCache bool
} }
// shared filter flags - used by all commands // shared filter flags - used by all commands
@@ -35,6 +38,13 @@ var (
filterIPv6 bool filterIPv6 bool
) )
// shared resolution flags - used by all commands
var (
resolveAddrs bool
resolvePorts bool
noCache bool
)
// BuildFilters constructs FilterOptions from command args and shortcut flags. // BuildFilters constructs FilterOptions from command args and shortcut flags.
func BuildFilters(args []string) (collector.FilterOptions, error) { func BuildFilters(args []string) (collector.FilterOptions, error) {
filters, err := ParseFilterArgs(args) filters, err := ParseFilterArgs(args)
@@ -77,6 +87,12 @@ func FetchConnections(filters collector.FilterOptions) ([]collector.Connection,
func NewRuntime(args []string, colorMode string) (*Runtime, error) { func NewRuntime(args []string, colorMode string) (*Runtime, error) {
color.Init(colorMode) color.Init(colorMode)
cfg := config.Get()
// configure resolver with cache setting (flag overrides config)
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
resolver.SetNoCache(effectiveNoCache)
filters, err := BuildFilters(args) filters, err := BuildFilters(args)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse filters: %w", err) return nil, fmt.Errorf("failed to parse filters: %w", err)
@@ -87,13 +103,30 @@ func NewRuntime(args []string, colorMode string) (*Runtime, error) {
return nil, fmt.Errorf("failed to fetch connections: %w", err) return nil, fmt.Errorf("failed to fetch connections: %w", err)
} }
return &Runtime{ rt := &Runtime{
Filters: filters, Filters: filters,
Connections: connections, Connections: connections,
ColorMode: colorMode, ColorMode: colorMode,
ResolveAddrs: resolveAddrs, ResolveAddrs: resolveAddrs,
ResolvePorts: resolvePorts, ResolvePorts: resolvePorts,
}, nil NoCache: effectiveNoCache,
}
// pre-warm dns cache by resolving all addresses in parallel
if resolveAddrs {
rt.PreWarmDNS()
}
return rt, nil
}
// PreWarmDNS resolves all connection addresses in parallel to warm the cache.
func (r *Runtime) PreWarmDNS() {
addrs := make([]string, 0, len(r.Connections)*2)
for _, c := range r.Connections {
addrs = append(addrs, c.Laddr, c.Raddr)
}
resolver.ResolveAddrsParallel(addrs)
} }
// SortConnections sorts the runtime's connections in place. // SortConnections sorts the runtime's connections in place.
@@ -201,3 +234,11 @@ func addFilterFlags(cmd *cobra.Command) {
cmd.Flags().BoolVarP(&filterIPv6, "ipv6", "6", false, "Only show IPv6 connections") cmd.Flags().BoolVarP(&filterIPv6, "ipv6", "6", false, "Only show IPv6 connections")
} }
// addResolutionFlags adds the common resolution flags to a command.
func addResolutionFlags(cmd *cobra.Command) {
cfg := config.Get()
cmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
cmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
cmd.Flags().BoolVar(&noCache, "no-cache", !cfg.Defaults.DNSCache, "Disable DNS caching (force fresh lookups)")
}

525
cmd/runtime_test.go Normal file
View File

@@ -0,0 +1,525 @@
package cmd
import (
"testing"
"github.com/karol-broda/snitch/internal/collector"
)
func TestParseFilterArgs_Empty(t *testing.T) {
filters, err := ParseFilterArgs([]string{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proto != "" {
t.Errorf("expected empty proto, got %q", filters.Proto)
}
}
func TestParseFilterArgs_Proto(t *testing.T) {
filters, err := ParseFilterArgs([]string{"proto=tcp"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proto != "tcp" {
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
}
}
func TestParseFilterArgs_State(t *testing.T) {
filters, err := ParseFilterArgs([]string{"state=LISTEN"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.State != "LISTEN" {
t.Errorf("expected state 'LISTEN', got %q", filters.State)
}
}
func TestParseFilterArgs_PID(t *testing.T) {
filters, err := ParseFilterArgs([]string{"pid=1234"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Pid != 1234 {
t.Errorf("expected pid 1234, got %d", filters.Pid)
}
}
func TestParseFilterArgs_InvalidPID(t *testing.T) {
_, err := ParseFilterArgs([]string{"pid=notanumber"})
if err == nil {
t.Error("expected error for invalid pid")
}
}
func TestParseFilterArgs_Proc(t *testing.T) {
filters, err := ParseFilterArgs([]string{"proc=nginx"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proc != "nginx" {
t.Errorf("expected proc 'nginx', got %q", filters.Proc)
}
}
func TestParseFilterArgs_Lport(t *testing.T) {
filters, err := ParseFilterArgs([]string{"lport=80"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Lport != 80 {
t.Errorf("expected lport 80, got %d", filters.Lport)
}
}
func TestParseFilterArgs_InvalidLport(t *testing.T) {
_, err := ParseFilterArgs([]string{"lport=notaport"})
if err == nil {
t.Error("expected error for invalid lport")
}
}
func TestParseFilterArgs_Rport(t *testing.T) {
filters, err := ParseFilterArgs([]string{"rport=443"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Rport != 443 {
t.Errorf("expected rport 443, got %d", filters.Rport)
}
}
func TestParseFilterArgs_InvalidRport(t *testing.T) {
_, err := ParseFilterArgs([]string{"rport=invalid"})
if err == nil {
t.Error("expected error for invalid rport")
}
}
func TestParseFilterArgs_UserByName(t *testing.T) {
filters, err := ParseFilterArgs([]string{"user=root"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.User != "root" {
t.Errorf("expected user 'root', got %q", filters.User)
}
}
func TestParseFilterArgs_UserByUID(t *testing.T) {
filters, err := ParseFilterArgs([]string{"user=1000"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.UID != 1000 {
t.Errorf("expected uid 1000, got %d", filters.UID)
}
}
func TestParseFilterArgs_Laddr(t *testing.T) {
filters, err := ParseFilterArgs([]string{"laddr=127.0.0.1"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Laddr != "127.0.0.1" {
t.Errorf("expected laddr '127.0.0.1', got %q", filters.Laddr)
}
}
func TestParseFilterArgs_Raddr(t *testing.T) {
filters, err := ParseFilterArgs([]string{"raddr=8.8.8.8"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Raddr != "8.8.8.8" {
t.Errorf("expected raddr '8.8.8.8', got %q", filters.Raddr)
}
}
func TestParseFilterArgs_Contains(t *testing.T) {
filters, err := ParseFilterArgs([]string{"contains=google"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Contains != "google" {
t.Errorf("expected contains 'google', got %q", filters.Contains)
}
}
func TestParseFilterArgs_Interface(t *testing.T) {
filters, err := ParseFilterArgs([]string{"if=eth0"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Interface != "eth0" {
t.Errorf("expected interface 'eth0', got %q", filters.Interface)
}
// test alternative syntax
filters2, err := ParseFilterArgs([]string{"interface=lo"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters2.Interface != "lo" {
t.Errorf("expected interface 'lo', got %q", filters2.Interface)
}
}
func TestParseFilterArgs_Mark(t *testing.T) {
filters, err := ParseFilterArgs([]string{"mark=0x1234"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Mark != "0x1234" {
t.Errorf("expected mark '0x1234', got %q", filters.Mark)
}
}
func TestParseFilterArgs_Namespace(t *testing.T) {
filters, err := ParseFilterArgs([]string{"namespace=default"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Namespace != "default" {
t.Errorf("expected namespace 'default', got %q", filters.Namespace)
}
}
func TestParseFilterArgs_Inode(t *testing.T) {
filters, err := ParseFilterArgs([]string{"inode=123456"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Inode != 123456 {
t.Errorf("expected inode 123456, got %d", filters.Inode)
}
}
func TestParseFilterArgs_InvalidInode(t *testing.T) {
_, err := ParseFilterArgs([]string{"inode=notanumber"})
if err == nil {
t.Error("expected error for invalid inode")
}
}
func TestParseFilterArgs_Multiple(t *testing.T) {
filters, err := ParseFilterArgs([]string{"proto=tcp", "state=LISTEN", "lport=80"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proto != "tcp" {
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
}
if filters.State != "LISTEN" {
t.Errorf("expected state 'LISTEN', got %q", filters.State)
}
if filters.Lport != 80 {
t.Errorf("expected lport 80, got %d", filters.Lport)
}
}
func TestParseFilterArgs_InvalidFormat(t *testing.T) {
_, err := ParseFilterArgs([]string{"invalidformat"})
if err == nil {
t.Error("expected error for invalid format")
}
}
func TestParseFilterArgs_UnknownKey(t *testing.T) {
_, err := ParseFilterArgs([]string{"unknownkey=value"})
if err == nil {
t.Error("expected error for unknown key")
}
}
func TestParseFilterArgs_CaseInsensitiveKeys(t *testing.T) {
filters, err := ParseFilterArgs([]string{"PROTO=tcp", "State=LISTEN"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proto != "tcp" {
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
}
if filters.State != "LISTEN" {
t.Errorf("expected state 'LISTEN', got %q", filters.State)
}
}
func TestBuildFilters_TCPOnly(t *testing.T) {
// save and restore global flags
oldTCP, oldUDP := filterTCP, filterUDP
defer func() {
filterTCP, filterUDP = oldTCP, oldUDP
}()
filterTCP = true
filterUDP = false
filters, err := BuildFilters([]string{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proto != "tcp" {
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
}
}
func TestBuildFilters_UDPOnly(t *testing.T) {
oldTCP, oldUDP := filterTCP, filterUDP
defer func() {
filterTCP, filterUDP = oldTCP, oldUDP
}()
filterTCP = false
filterUDP = true
filters, err := BuildFilters([]string{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proto != "udp" {
t.Errorf("expected proto 'udp', got %q", filters.Proto)
}
}
func TestBuildFilters_ListenOnly(t *testing.T) {
oldListen, oldEstab := filterListen, filterEstab
defer func() {
filterListen, filterEstab = oldListen, oldEstab
}()
filterListen = true
filterEstab = false
filters, err := BuildFilters([]string{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.State != "LISTEN" {
t.Errorf("expected state 'LISTEN', got %q", filters.State)
}
}
func TestBuildFilters_EstablishedOnly(t *testing.T) {
oldListen, oldEstab := filterListen, filterEstab
defer func() {
filterListen, filterEstab = oldListen, oldEstab
}()
filterListen = false
filterEstab = true
filters, err := BuildFilters([]string{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.State != "ESTABLISHED" {
t.Errorf("expected state 'ESTABLISHED', got %q", filters.State)
}
}
func TestBuildFilters_IPv4Flag(t *testing.T) {
oldIPv4 := filterIPv4
defer func() {
filterIPv4 = oldIPv4
}()
filterIPv4 = true
filters, err := BuildFilters([]string{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !filters.IPv4 {
t.Error("expected IPv4 to be true")
}
}
func TestBuildFilters_IPv6Flag(t *testing.T) {
oldIPv6 := filterIPv6
defer func() {
filterIPv6 = oldIPv6
}()
filterIPv6 = true
filters, err := BuildFilters([]string{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !filters.IPv6 {
t.Error("expected IPv6 to be true")
}
}
func TestBuildFilters_CombinedArgsAndFlags(t *testing.T) {
oldTCP := filterTCP
defer func() {
filterTCP = oldTCP
}()
filterTCP = true
filters, err := BuildFilters([]string{"lport=80"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filters.Proto != "tcp" {
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
}
if filters.Lport != 80 {
t.Errorf("expected lport 80, got %d", filters.Lport)
}
}
func TestRuntime_PreWarmDNS(t *testing.T) {
rt := &Runtime{
Connections: []collector.Connection{
{Laddr: "127.0.0.1", Raddr: "192.168.1.1"},
{Laddr: "127.0.0.1", Raddr: "10.0.0.1"},
},
}
// should not panic
rt.PreWarmDNS()
}
func TestRuntime_PreWarmDNS_Empty(t *testing.T) {
rt := &Runtime{
Connections: []collector.Connection{},
}
// should not panic with empty connections
rt.PreWarmDNS()
}
func TestRuntime_SortConnections(t *testing.T) {
rt := &Runtime{
Connections: []collector.Connection{
{Lport: 443},
{Lport: 80},
{Lport: 8080},
},
}
rt.SortConnections(collector.SortOptions{
Field: collector.SortByLport,
Direction: collector.SortAsc,
})
if rt.Connections[0].Lport != 80 {
t.Errorf("expected first connection to have lport 80, got %d", rt.Connections[0].Lport)
}
if rt.Connections[1].Lport != 443 {
t.Errorf("expected second connection to have lport 443, got %d", rt.Connections[1].Lport)
}
if rt.Connections[2].Lport != 8080 {
t.Errorf("expected third connection to have lport 8080, got %d", rt.Connections[2].Lport)
}
}
func TestRuntime_SortConnections_Desc(t *testing.T) {
rt := &Runtime{
Connections: []collector.Connection{
{Lport: 80},
{Lport: 443},
{Lport: 8080},
},
}
rt.SortConnections(collector.SortOptions{
Field: collector.SortByLport,
Direction: collector.SortDesc,
})
if rt.Connections[0].Lport != 8080 {
t.Errorf("expected first connection to have lport 8080, got %d", rt.Connections[0].Lport)
}
}
func TestApplyFilter_AllKeys(t *testing.T) {
tests := []struct {
key string
value string
validate func(t *testing.T, f *collector.FilterOptions)
}{
{"proto", "tcp", func(t *testing.T, f *collector.FilterOptions) {
if f.Proto != "tcp" {
t.Errorf("proto: expected 'tcp', got %q", f.Proto)
}
}},
{"state", "LISTEN", func(t *testing.T, f *collector.FilterOptions) {
if f.State != "LISTEN" {
t.Errorf("state: expected 'LISTEN', got %q", f.State)
}
}},
{"pid", "100", func(t *testing.T, f *collector.FilterOptions) {
if f.Pid != 100 {
t.Errorf("pid: expected 100, got %d", f.Pid)
}
}},
{"proc", "nginx", func(t *testing.T, f *collector.FilterOptions) {
if f.Proc != "nginx" {
t.Errorf("proc: expected 'nginx', got %q", f.Proc)
}
}},
{"lport", "80", func(t *testing.T, f *collector.FilterOptions) {
if f.Lport != 80 {
t.Errorf("lport: expected 80, got %d", f.Lport)
}
}},
{"rport", "443", func(t *testing.T, f *collector.FilterOptions) {
if f.Rport != 443 {
t.Errorf("rport: expected 443, got %d", f.Rport)
}
}},
{"laddr", "127.0.0.1", func(t *testing.T, f *collector.FilterOptions) {
if f.Laddr != "127.0.0.1" {
t.Errorf("laddr: expected '127.0.0.1', got %q", f.Laddr)
}
}},
{"raddr", "8.8.8.8", func(t *testing.T, f *collector.FilterOptions) {
if f.Raddr != "8.8.8.8" {
t.Errorf("raddr: expected '8.8.8.8', got %q", f.Raddr)
}
}},
{"contains", "test", func(t *testing.T, f *collector.FilterOptions) {
if f.Contains != "test" {
t.Errorf("contains: expected 'test', got %q", f.Contains)
}
}},
{"if", "eth0", func(t *testing.T, f *collector.FilterOptions) {
if f.Interface != "eth0" {
t.Errorf("interface: expected 'eth0', got %q", f.Interface)
}
}},
{"mark", "0xff", func(t *testing.T, f *collector.FilterOptions) {
if f.Mark != "0xff" {
t.Errorf("mark: expected '0xff', got %q", f.Mark)
}
}},
{"namespace", "ns1", func(t *testing.T, f *collector.FilterOptions) {
if f.Namespace != "ns1" {
t.Errorf("namespace: expected 'ns1', got %q", f.Namespace)
}
}},
{"inode", "12345", func(t *testing.T, f *collector.FilterOptions) {
if f.Inode != 12345 {
t.Errorf("inode: expected 12345, got %d", f.Inode)
}
}},
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
filters := &collector.FilterOptions{}
err := applyFilter(filters, tt.key, tt.value)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tt.validate(t, filters)
})
}
}

View File

@@ -8,7 +8,6 @@ import (
"log" "log"
"os" "os"
"os/signal" "os/signal"
"github.com/karol-broda/snitch/internal/collector"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -17,6 +16,9 @@ import (
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/errutil"
) )
type StatsData struct { type StatsData struct {
@@ -227,19 +229,19 @@ func printStatsCSV(stats *StatsData, headers bool) {
func printStatsTable(stats *StatsData, headers bool) { func printStatsTable(stats *StatsData, headers bool) {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
defer w.Flush() defer errutil.Flush(w)
if headers { if headers {
fmt.Fprintf(w, "TIMESTAMP\t%s\n", stats.Timestamp.Format(time.RFC3339)) errutil.Ignore(fmt.Fprintf(w, "TIMESTAMP\t%s\n", stats.Timestamp.Format(time.RFC3339)))
fmt.Fprintf(w, "TOTAL CONNECTIONS\t%d\n", stats.Total) errutil.Ignore(fmt.Fprintf(w, "TOTAL CONNECTIONS\t%d\n", stats.Total))
fmt.Fprintln(w) errutil.Ignore(fmt.Fprintln(w))
} }
// Protocol breakdown // Protocol breakdown
if len(stats.ByProto) > 0 { if len(stats.ByProto) > 0 {
if headers { if headers {
fmt.Fprintln(w, "BY PROTOCOL:") errutil.Ignore(fmt.Fprintln(w, "BY PROTOCOL:"))
fmt.Fprintln(w, "PROTO\tCOUNT") errutil.Ignore(fmt.Fprintln(w, "PROTO\tCOUNT"))
} }
protocols := make([]string, 0, len(stats.ByProto)) protocols := make([]string, 0, len(stats.ByProto))
for proto := range stats.ByProto { for proto := range stats.ByProto {
@@ -247,16 +249,16 @@ func printStatsTable(stats *StatsData, headers bool) {
} }
sort.Strings(protocols) sort.Strings(protocols)
for _, proto := range protocols { for _, proto := range protocols {
fmt.Fprintf(w, "%s\t%d\n", strings.ToUpper(proto), stats.ByProto[proto]) errutil.Ignore(fmt.Fprintf(w, "%s\t%d\n", strings.ToUpper(proto), stats.ByProto[proto]))
} }
fmt.Fprintln(w) errutil.Ignore(fmt.Fprintln(w))
} }
// State breakdown // State breakdown
if len(stats.ByState) > 0 { if len(stats.ByState) > 0 {
if headers { if headers {
fmt.Fprintln(w, "BY STATE:") errutil.Ignore(fmt.Fprintln(w, "BY STATE:"))
fmt.Fprintln(w, "STATE\tCOUNT") errutil.Ignore(fmt.Fprintln(w, "STATE\tCOUNT"))
} }
states := make([]string, 0, len(stats.ByState)) states := make([]string, 0, len(stats.ByState))
for state := range stats.ByState { for state := range stats.ByState {
@@ -264,16 +266,16 @@ func printStatsTable(stats *StatsData, headers bool) {
} }
sort.Strings(states) sort.Strings(states)
for _, state := range states { for _, state := range states {
fmt.Fprintf(w, "%s\t%d\n", state, stats.ByState[state]) errutil.Ignore(fmt.Fprintf(w, "%s\t%d\n", state, stats.ByState[state]))
} }
fmt.Fprintln(w) errutil.Ignore(fmt.Fprintln(w))
} }
// Process breakdown (top 10) // Process breakdown (top 10)
if len(stats.ByProc) > 0 { if len(stats.ByProc) > 0 {
if headers { if headers {
fmt.Fprintln(w, "BY PROCESS (TOP 10):") errutil.Ignore(fmt.Fprintln(w, "BY PROCESS (TOP 10):"))
fmt.Fprintln(w, "PID\tPROCESS\tCOUNT") errutil.Ignore(fmt.Fprintln(w, "PID\tPROCESS\tCOUNT"))
} }
limit := 10 limit := 10
if len(stats.ByProc) < limit { if len(stats.ByProc) < limit {
@@ -281,7 +283,7 @@ func printStatsTable(stats *StatsData, headers bool) {
} }
for i := 0; i < limit; i++ { for i := 0; i < limit; i++ {
proc := stats.ByProc[i] proc := stats.ByProc[i]
fmt.Fprintf(w, "%d\t%s\t%d\n", proc.PID, proc.Process, proc.Count) errutil.Ignore(fmt.Fprintf(w, "%d\t%s\t%d\n", proc.PID, proc.Process, proc.Count))
} }
} }
} }

24
cmd/themes.go Normal file
View File

@@ -0,0 +1,24 @@
package cmd
import (
"fmt"
"github.com/karol-broda/snitch/internal/theme"
"github.com/spf13/cobra"
)
var themesCmd = &cobra.Command{
Use: "themes",
Short: "List available themes",
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("Available themes (default: %s):\n\n", theme.DefaultTheme)
for _, name := range theme.ListThemes() {
fmt.Printf(" %s\n", name)
}
},
}
func init() {
rootCmd.AddCommand(themesCmd)
}

View File

@@ -2,20 +2,19 @@ package cmd
import ( import (
"log" "log"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/tui"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/resolver"
"github.com/karol-broda/snitch/internal/tui"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// top-specific flags // top-specific flags
var ( var (
topTheme string topTheme string
topInterval time.Duration topInterval time.Duration
topResolveAddrs bool
topResolvePorts bool
) )
var topCmd = &cobra.Command{ var topCmd = &cobra.Command{
@@ -29,11 +28,17 @@ var topCmd = &cobra.Command{
theme = cfg.Defaults.Theme theme = cfg.Defaults.Theme
} }
// configure resolver with cache setting
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
resolver.SetNoCache(effectiveNoCache)
opts := tui.Options{ opts := tui.Options{
Theme: theme, Theme: theme,
Interval: topInterval, Interval: topInterval,
ResolveAddrs: topResolveAddrs, ResolveAddrs: resolveAddrs,
ResolvePorts: topResolvePorts, ResolvePorts: resolvePorts,
NoCache: effectiveNoCache,
RememberState: cfg.TUI.RememberState,
} }
// if any filter flag is set, use exclusive mode // if any filter flag is set, use exclusive mode
@@ -60,11 +65,10 @@ func init() {
cfg := config.Get() cfg := config.Get()
// top-specific flags // top-specific flags
topCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (dark, light, mono, auto)") topCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
topCmd.Flags().DurationVarP(&topInterval, "interval", "i", time.Second, "Refresh interval") topCmd.Flags().DurationVarP(&topInterval, "interval", "i", time.Second, "Refresh interval")
topCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
topCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
// shared filter flags // shared flags
addFilterFlags(topCmd) addFilterFlags(topCmd)
addResolutionFlags(topCmd)
} }

View File

@@ -7,12 +7,14 @@ import (
"log" "log"
"os" "os"
"os/signal" "os/signal"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/resolver"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/resolver"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -23,11 +25,10 @@ type TraceEvent struct {
} }
var ( var (
traceInterval time.Duration traceInterval time.Duration
traceCount int traceCount int
traceOutputFormat string traceOutputFormat string
traceNumeric bool traceTimestamp bool
traceTimestamp bool
) )
var traceCmd = &cobra.Command{ var traceCmd = &cobra.Command{
@@ -47,6 +48,12 @@ Available filters:
} }
func runTraceCommand(args []string) { func runTraceCommand(args []string) {
cfg := config.Get()
// configure resolver with cache setting
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
resolver.SetNoCache(effectiveNoCache)
filters, err := BuildFilters(args) filters, err := BuildFilters(args)
if err != nil { if err != nil {
log.Fatalf("Error parsing filters: %v", err) log.Fatalf("Error parsing filters: %v", err)
@@ -180,14 +187,16 @@ func printTraceEventHuman(event TraceEvent) {
lportStr := fmt.Sprintf("%d", conn.Lport) lportStr := fmt.Sprintf("%d", conn.Lport)
rportStr := fmt.Sprintf("%d", conn.Rport) rportStr := fmt.Sprintf("%d", conn.Rport)
// Handle name resolution based on numeric flag // apply name resolution
if !traceNumeric { if resolveAddrs {
if resolvedLaddr := resolver.ResolveAddr(conn.Laddr); resolvedLaddr != conn.Laddr { if resolvedLaddr := resolver.ResolveAddr(conn.Laddr); resolvedLaddr != conn.Laddr {
laddr = resolvedLaddr laddr = resolvedLaddr
} }
if resolvedRaddr := resolver.ResolveAddr(conn.Raddr); resolvedRaddr != conn.Raddr && conn.Raddr != "*" && conn.Raddr != "" { if resolvedRaddr := resolver.ResolveAddr(conn.Raddr); resolvedRaddr != conn.Raddr && conn.Raddr != "*" && conn.Raddr != "" {
raddr = resolvedRaddr raddr = resolvedRaddr
} }
}
if resolvePorts {
if resolvedLport := resolver.ResolvePort(conn.Lport, conn.Proto); resolvedLport != fmt.Sprintf("%d", conn.Lport) { if resolvedLport := resolver.ResolvePort(conn.Lport, conn.Proto); resolvedLport != fmt.Sprintf("%d", conn.Lport) {
lportStr = resolvedLport lportStr = resolvedLport
} }
@@ -225,9 +234,9 @@ func init() {
traceCmd.Flags().DurationVarP(&traceInterval, "interval", "i", time.Second, "Polling interval (e.g., 500ms, 2s)") traceCmd.Flags().DurationVarP(&traceInterval, "interval", "i", time.Second, "Polling interval (e.g., 500ms, 2s)")
traceCmd.Flags().IntVarP(&traceCount, "count", "c", 0, "Number of events to capture (0 = unlimited)") traceCmd.Flags().IntVarP(&traceCount, "count", "c", 0, "Number of events to capture (0 = unlimited)")
traceCmd.Flags().StringVarP(&traceOutputFormat, "output", "o", "human", "Output format (human, json)") traceCmd.Flags().StringVarP(&traceOutputFormat, "output", "o", "human", "Output format (human, json)")
traceCmd.Flags().BoolVarP(&traceNumeric, "numeric", "n", false, "Don't resolve hostnames")
traceCmd.Flags().BoolVar(&traceTimestamp, "ts", false, "Include timestamp in output") traceCmd.Flags().BoolVar(&traceTimestamp, "ts", false, "Include timestamp in output")
// shared filter flags // shared flags
addFilterFlags(traceCmd) addFilterFlags(traceCmd)
addResolutionFlags(traceCmd)
} }

View File

@@ -18,6 +18,7 @@ import (
"github.com/fatih/color" "github.com/fatih/color"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/errutil"
"github.com/karol-broda/snitch/internal/tui" "github.com/karol-broda/snitch/internal/tui"
) )
@@ -93,13 +94,13 @@ func runUpgrade(cmd *cobra.Command, args []string) error {
if currentClean == latestClean { if currentClean == latestClean {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Println(tui.SymbolSuccess + " you are running the latest version") errutil.Println(green, tui.SymbolSuccess+" you are running the latest version")
return nil return nil
} }
if current == "dev" { if current == "dev" {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " you are running a development build") errutil.Println(yellow, tui.SymbolWarning+" you are running a development build")
fmt.Println() fmt.Println()
fmt.Println("use one of the methods below to install a release version:") fmt.Println("use one of the methods below to install a release version:")
fmt.Println() fmt.Println()
@@ -108,7 +109,7 @@ func runUpgrade(cmd *cobra.Command, args []string) error {
} }
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", current, latest) errutil.Printf(green, tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", current, latest)
fmt.Println() fmt.Println()
if !upgradeYes { if !upgradeYes {
@@ -116,8 +117,8 @@ func runUpgrade(cmd *cobra.Command, args []string) error {
fmt.Println() fmt.Println()
faint := color.New(color.Faint) faint := color.New(color.Faint)
cmdStyle := color.New(color.FgCyan) cmdStyle := color.New(color.FgCyan)
faint.Print(" in-place ") errutil.Print(faint, " in-place ")
cmdStyle.Println("snitch upgrade --yes") errutil.Println(cmdStyle, "snitch upgrade --yes")
return nil return nil
} }
@@ -134,17 +135,17 @@ func handleSpecificVersion(current, target string) error {
if isVersionLower(targetClean, firstUpgradeVersion) { if isVersionLower(targetClean, firstUpgradeVersion) {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Printf(tui.SymbolWarning+" warning: the upgrade command was introduced in v%s\n", firstUpgradeVersion) errutil.Printf(yellow, tui.SymbolWarning+" warning: the upgrade command was introduced in v%s\n", firstUpgradeVersion)
faint := color.New(color.Faint) faint := color.New(color.Faint)
faint.Printf(" version %s does not include this command\n", target) errutil.Printf(faint, " version %s does not include this command\n", target)
faint.Println(" you will need to use other methods to upgrade from that version") errutil.Println(faint, " you will need to use other methods to upgrade from that version")
fmt.Println() fmt.Println()
} }
currentClean := strings.TrimPrefix(current, "v") currentClean := strings.TrimPrefix(current, "v")
if currentClean == targetClean { if currentClean == targetClean {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Println(tui.SymbolSuccess + " you are already running this version") errutil.Println(green, tui.SymbolSuccess+" you are already running this version")
return nil return nil
} }
@@ -153,15 +154,15 @@ func handleSpecificVersion(current, target string) error {
cmdStyle := color.New(color.FgCyan) cmdStyle := color.New(color.FgCyan)
if isVersionLower(targetClean, currentClean) { if isVersionLower(targetClean, currentClean) {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Printf(tui.SymbolArrowDown+" this will downgrade from %s to %s\n", current, target) errutil.Printf(yellow, tui.SymbolArrowDown+" this will downgrade from %s to %s\n", current, target)
} else { } else {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Printf(tui.SymbolArrowUp+" this will upgrade from %s to %s\n", current, target) errutil.Printf(green, tui.SymbolArrowUp+" this will upgrade from %s to %s\n", current, target)
} }
fmt.Println() fmt.Println()
faint.Print("run ") errutil.Print(faint, "run ")
cmdStyle.Printf("snitch upgrade --version %s --yes", target) errutil.Printf(cmdStyle, "snitch upgrade --version %s --yes", target)
faint.Println(" to proceed") errutil.Println(faint, " to proceed")
return nil return nil
} }
@@ -175,20 +176,20 @@ func handleNixUpgrade(current, latest string) error {
currentCommit := extractCommitFromVersion(current) currentCommit := extractCommitFromVersion(current)
dirty := isNixDirty(current) dirty := isNixDirty(current)
faint.Print("current ") errutil.Print(faint, "current ")
version.Print(current) errutil.Print(version, current)
if currentCommit != "" { if currentCommit != "" {
faint.Printf(" (commit %s)", currentCommit) errutil.Printf(faint, " (commit %s)", currentCommit)
} }
fmt.Println() fmt.Println()
faint.Print("latest ") errutil.Print(faint, "latest ")
version.Println(latest) errutil.Println(version, latest)
fmt.Println() fmt.Println()
if dirty { if dirty {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " you are running a dirty nix build (uncommitted changes)") errutil.Println(yellow, tui.SymbolWarning+" you are running a dirty nix build (uncommitted changes)")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -196,8 +197,8 @@ func handleNixUpgrade(current, latest string) error {
if currentCommit == "" { if currentCommit == "" {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -205,11 +206,11 @@ func handleNixUpgrade(current, latest string) error {
releaseCommit, err := fetchCommitForTag(latest) releaseCommit, err := fetchCommitForTag(latest)
if err != nil { if err != nil {
faint.Printf(" (could not fetch release commit: %v)\n", err) errutil.Printf(faint, " (could not fetch release commit: %v)\n", err)
fmt.Println() fmt.Println()
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -222,20 +223,20 @@ func handleNixUpgrade(current, latest string) error {
if strings.HasPrefix(releaseCommit, currentCommit) || strings.HasPrefix(currentCommit, releaseShort) { if strings.HasPrefix(releaseCommit, currentCommit) || strings.HasPrefix(currentCommit, releaseShort) {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Printf(tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort) errutil.Printf(green, tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort)
return nil return nil
} }
comparison, err := compareCommits(latest, currentCommit) comparison, err := compareCommits(latest, currentCommit)
if err != nil { if err != nil {
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", currentCommit, latest) errutil.Printf(green, tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", currentCommit, latest)
faint.Printf(" your commit: %s\n", currentCommit) errutil.Printf(faint, " your commit: %s\n", currentCommit)
faint.Printf(" release: %s (%s)\n", releaseShort, latest) errutil.Printf(faint, " release: %s (%s)\n", releaseShort, latest)
fmt.Println() fmt.Println()
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -243,30 +244,30 @@ func handleNixUpgrade(current, latest string) error {
if comparison.AheadBy > 0 { if comparison.AheadBy > 0 {
cyan := color.New(color.FgCyan) cyan := color.New(color.FgCyan)
cyan.Printf(tui.SymbolArrowUp+" you are %d commit(s) ahead of %s\n", comparison.AheadBy, latest) errutil.Printf(cyan, tui.SymbolArrowUp+" you are %d commit(s) ahead of %s\n", comparison.AheadBy, latest)
faint.Printf(" your commit: %s\n", currentCommit) errutil.Printf(faint, " your commit: %s\n", currentCommit)
faint.Printf(" release: %s (%s)\n", releaseShort, latest) errutil.Printf(faint, " release: %s (%s)\n", releaseShort, latest)
fmt.Println() fmt.Println()
faint.Println("you are running a newer build than the latest release") errutil.Println(faint, "you are running a newer build than the latest release")
return nil return nil
} }
if comparison.BehindBy > 0 { if comparison.BehindBy > 0 {
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess+" update available: %d commit(s) behind %s\n", comparison.BehindBy, latest) errutil.Printf(green, tui.SymbolSuccess+" update available: %d commit(s) behind %s\n", comparison.BehindBy, latest)
faint.Printf(" your commit: %s\n", currentCommit) errutil.Printf(faint, " your commit: %s\n", currentCommit)
faint.Printf(" release: %s (%s)\n", releaseShort, latest) errutil.Printf(faint, " release: %s (%s)\n", releaseShort, latest)
fmt.Println() fmt.Println()
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
} }
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Printf(tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort) errutil.Printf(green, tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort)
return nil return nil
} }
@@ -278,22 +279,22 @@ func handleNixSpecificVersion(current, target string) error {
printVersionComparisonTarget(current, target) printVersionComparisonTarget(current, target)
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint := color.New(color.Faint) faint := color.New(color.Faint)
faint.Println(" nix store is immutable; in-place upgrades are not supported") errutil.Println(faint, " nix store is immutable; in-place upgrades are not supported")
fmt.Println() fmt.Println()
bold := color.New(color.Bold) bold := color.New(color.Bold)
cmd := color.New(color.FgCyan) cmd := color.New(color.FgCyan)
bold.Println("to install a specific version with nix:") errutil.Println(bold, "to install a specific version with nix:")
fmt.Println() fmt.Println()
faint.Print(" specific ref ") errutil.Print(faint, " specific ref ")
cmd.Printf("nix profile install github:%s/%s/%s\n", repoOwner, repoName, target) errutil.Printf(cmd, "nix profile install github:%s/%s/%s\n", repoOwner, repoName, target)
faint.Print(" latest ") errutil.Print(faint, " latest ")
cmd.Printf("nix profile install github:%s/%s\n", repoOwner, repoName) errutil.Printf(cmd, "nix profile install github:%s/%s\n", repoOwner, repoName)
return nil return nil
} }
@@ -333,7 +334,7 @@ func fetchLatestVersion() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("github api returned status %d", resp.StatusCode) return "", fmt.Errorf("github api returned status %d", resp.StatusCode)
@@ -355,10 +356,10 @@ func printVersionComparison(current, latest string) {
faint := color.New(color.Faint) faint := color.New(color.Faint)
version := color.New(color.FgCyan) version := color.New(color.FgCyan)
faint.Print("current ") errutil.Print(faint, "current ")
version.Println(current) errutil.Println(version, current)
faint.Print("latest ") errutil.Print(faint, "latest ")
version.Println(latest) errutil.Println(version, latest)
fmt.Println() fmt.Println()
} }
@@ -366,10 +367,10 @@ func printVersionComparisonTarget(current, target string) {
faint := color.New(color.Faint) faint := color.New(color.Faint)
version := color.New(color.FgCyan) version := color.New(color.FgCyan)
faint.Print("current ") errutil.Print(faint, "current ")
version.Println(current) errutil.Println(version, current)
faint.Print("target ") errutil.Print(faint, "target ")
version.Println(target) errutil.Println(version, target)
fmt.Println() fmt.Println()
} }
@@ -378,20 +379,20 @@ func printUpgradeInstructions() {
faint := color.New(color.Faint) faint := color.New(color.Faint)
cmd := color.New(color.FgCyan) cmd := color.New(color.FgCyan)
bold.Println("upgrade options:") errutil.Println(bold, "upgrade options:")
fmt.Println() fmt.Println()
faint.Print(" go install ") errutil.Print(faint, " go install ")
cmd.Printf("go install github.com/%s/%s@latest\n", repoOwner, repoName) errutil.Printf(cmd, "go install github.com/%s/%s@latest\n", repoOwner, repoName)
faint.Print(" shell script ") errutil.Print(faint, " shell script ")
cmd.Printf("curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | sh\n", repoOwner, repoName) errutil.Printf(cmd, "curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | sh\n", repoOwner, repoName)
faint.Print(" arch (aur) ") errutil.Print(faint, " arch (aur) ")
cmd.Println("yay -S snitch-bin") errutil.Println(cmd, "yay -S snitch-bin")
faint.Print(" nix ") errutil.Print(faint, " nix ")
cmd.Printf("nix profile upgrade --inputs-from github:%s/%s\n", repoOwner, repoName) errutil.Printf(cmd, "nix profile upgrade --inputs-from github:%s/%s\n", repoOwner, repoName)
} }
func performUpgrade(version string) error { func performUpgrade(version string) error {
@@ -407,7 +408,7 @@ func performUpgrade(version string) error {
if strings.HasPrefix(execPath, "/nix/store/") { if strings.HasPrefix(execPath, "/nix/store/") {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " cannot perform in-place upgrade for nix installation") errutil.Println(yellow, tui.SymbolWarning+" cannot perform in-place upgrade for nix installation")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -423,15 +424,15 @@ func performUpgrade(version string) error {
faint := color.New(color.Faint) faint := color.New(color.Faint)
cyan := color.New(color.FgCyan) cyan := color.New(color.FgCyan)
faint.Print(tui.SymbolDownload + " downloading ") errutil.Print(faint, tui.SymbolDownload+" downloading ")
cyan.Printf("%s", archiveName) errutil.Printf(cyan, "%s", archiveName)
faint.Println("...") errutil.Println(faint, "...")
resp, err := http.Get(downloadURL) resp, err := http.Get(downloadURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to download: %w", err) return fmt.Errorf("failed to download: %w", err)
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status %d", resp.StatusCode) return fmt.Errorf("download failed with status %d", resp.StatusCode)
@@ -441,7 +442,7 @@ func performUpgrade(version string) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to create temp directory: %w", err) return fmt.Errorf("failed to create temp directory: %w", err)
} }
defer os.RemoveAll(tmpDir) defer errutil.RemoveAll(tmpDir)
binaryPath, err := extractBinaryFromTarGz(resp.Body, tmpDir) binaryPath, err := extractBinaryFromTarGz(resp.Body, tmpDir)
if err != nil { if err != nil {
@@ -458,14 +459,14 @@ func performUpgrade(version string) error {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
cmdStyle := color.New(color.FgCyan) cmdStyle := color.New(color.FgCyan)
yellow.Printf(tui.SymbolWarning+" elevated permissions required to install to %s\n", targetDir) errutil.Printf(yellow, tui.SymbolWarning+" elevated permissions required to install to %s\n", targetDir)
fmt.Println() fmt.Println()
faint.Println("run with sudo or install to a user-writable location:") errutil.Println(faint, "run with sudo or install to a user-writable location:")
fmt.Println() fmt.Println()
faint.Print(" sudo ") errutil.Print(faint, " sudo ")
cmdStyle.Println("sudo snitch upgrade --yes") errutil.Println(cmdStyle, "sudo snitch upgrade --yes")
faint.Print(" custom dir ") errutil.Print(faint, " custom dir ")
cmdStyle.Printf("curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | INSTALL_DIR=~/.local/bin sh\n", errutil.Printf(cmdStyle, "curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | INSTALL_DIR=~/.local/bin sh\n",
repoOwner, repoName) repoOwner, repoName)
return nil return nil
} }
@@ -491,11 +492,11 @@ func performUpgrade(version string) error {
if err := os.Remove(backupPath); err != nil { if err := os.Remove(backupPath); err != nil {
// non-fatal, just warn // non-fatal, just warn
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Fprintf(os.Stderr, tui.SymbolWarning + " warning: failed to remove backup file %s: %v\n", backupPath, err) errutil.Fprintf(yellow, os.Stderr, tui.SymbolWarning+" warning: failed to remove backup file %s: %v\n", backupPath, err)
} }
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess + " successfully upgraded to %s\n", version) errutil.Printf(green, tui.SymbolSuccess+" successfully upgraded to %s\n", version)
return nil return nil
} }
@@ -504,7 +505,7 @@ func extractBinaryFromTarGz(r io.Reader, destDir string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer gzr.Close() defer errutil.Close(gzr)
tr := tar.NewReader(gzr) tr := tar.NewReader(gzr)
@@ -534,10 +535,10 @@ func extractBinaryFromTarGz(r io.Reader, destDir string) (string, error) {
} }
if _, err := io.Copy(outFile, tr); err != nil { if _, err := io.Copy(outFile, tr); err != nil {
outFile.Close() errutil.Close(outFile)
return "", err return "", err
} }
outFile.Close() errutil.Close(outFile)
return destPath, nil return destPath, nil
} }
@@ -551,8 +552,8 @@ func isWritable(path string) bool {
if err != nil { if err != nil {
return false return false
} }
f.Close() errutil.Close(f)
os.Remove(testFile) errutil.Remove(testFile)
return true return true
} }
@@ -561,13 +562,13 @@ func copyFile(src, dst string) error {
if err != nil { if err != nil {
return err return err
} }
defer srcFile.Close() defer errutil.Close(srcFile)
dstFile, err := os.Create(dst) dstFile, err := os.Create(dst)
if err != nil { if err != nil {
return err return err
} }
defer dstFile.Close() defer errutil.Close(dstFile)
if _, err := io.Copy(dstFile, srcFile); err != nil { if _, err := io.Copy(dstFile, srcFile); err != nil {
return err return err
@@ -580,7 +581,7 @@ func removeQuarantine(path string) {
cmd := exec.Command("xattr", "-d", "com.apple.quarantine", path) cmd := exec.Command("xattr", "-d", "com.apple.quarantine", path)
if err := cmd.Run(); err == nil { if err := cmd.Run(); err == nil {
faint := color.New(color.Faint) faint := color.New(color.Faint)
faint.Println(" removed macOS quarantine attribute") errutil.Println(faint, " removed macOS quarantine attribute")
} }
} }
@@ -633,7 +634,7 @@ func fetchCommitForTag(tag string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("github api returned status %d", resp.StatusCode) return "", fmt.Errorf("github api returned status %d", resp.StatusCode)
@@ -654,7 +655,7 @@ func compareCommits(base, head string) (*githubCompare, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("github api returned status %d", resp.StatusCode) return nil, fmt.Errorf("github api returned status %d", resp.StatusCode)
@@ -673,16 +674,16 @@ func printNixUpgradeInstructions() {
faint := color.New(color.Faint) faint := color.New(color.Faint)
cmd := color.New(color.FgCyan) cmd := color.New(color.FgCyan)
bold.Println("nix upgrade options:") errutil.Println(bold, "nix upgrade options:")
fmt.Println() fmt.Println()
faint.Print(" flake profile ") errutil.Print(faint, " flake profile ")
cmd.Printf("nix profile install github:%s/%s\n", repoOwner, repoName) errutil.Printf(cmd, "nix profile install github:%s/%s\n", repoOwner, repoName)
faint.Print(" flake update ") errutil.Print(faint, " flake update ")
cmd.Println("nix flake update snitch (in your system/home-manager config)") errutil.Println(cmd, "nix flake update snitch (in your system/home-manager config)")
faint.Print(" rebuild ") errutil.Print(faint, " rebuild ")
cmd.Println("nixos-rebuild switch or home-manager switch") errutil.Println(cmd, "nixos-rebuild switch or home-manager switch")
} }

View File

@@ -6,6 +6,8 @@ import (
"github.com/fatih/color" "github.com/fatih/color"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/errutil"
) )
var ( var (
@@ -22,20 +24,20 @@ var versionCmd = &cobra.Command{
cyan := color.New(color.FgCyan) cyan := color.New(color.FgCyan)
faint := color.New(color.Faint) faint := color.New(color.Faint)
bold.Print("snitch ") errutil.Print(bold, "snitch ")
cyan.Println(Version) errutil.Println(cyan, Version)
fmt.Println() fmt.Println()
faint.Print(" commit ") errutil.Print(faint, " commit ")
fmt.Println(Commit) fmt.Println(Commit)
faint.Print(" built ") errutil.Print(faint, " built ")
fmt.Println(Date) fmt.Println(Date)
faint.Print(" go ") errutil.Print(faint, " go ")
fmt.Println(runtime.Version()) fmt.Println(runtime.Version())
faint.Print(" os ") errutil.Print(faint, " os ")
fmt.Printf("%s/%s\n", runtime.GOOS, runtime.GOARCH) fmt.Printf("%s/%s\n", runtime.GOOS, runtime.GOARCH)
}, },
} }

View File

@@ -106,5 +106,32 @@
overlays.default = final: _prev: { overlays.default = final: _prev: {
snitch = mkSnitch final; snitch = mkSnitch final;
}; };
homeManagerModules.default = import ./nix/hm-module.nix;
homeManagerModules.snitch = self.homeManagerModules.default;
# alias for flake-parts compatibility
homeModules.default = self.homeManagerModules.default;
homeModules.snitch = self.homeManagerModules.default;
checks = eachSystem (system:
let
pkgs = import nixpkgs {
inherit system;
overlays = [ self.overlays.default ];
};
in
{
# home manager module tests
hm-module = import ./nix/tests/hm-module-test.nix {
inherit pkgs;
lib = pkgs.lib;
hmModule = self.homeManagerModules.default;
};
# package builds correctly
package = self.packages.${system}.default;
}
);
}; };
} }

12
go.mod
View File

@@ -1,23 +1,26 @@
module github.com/karol-broda/snitch module github.com/karol-broda/snitch
go 1.24.0 go 1.25.0
require ( require (
github.com/charmbracelet/bubbletea v1.3.6 github.com/charmbracelet/bubbletea v1.3.6
github.com/charmbracelet/lipgloss v1.1.0
github.com/charmbracelet/x/exp/teatest v0.0.0-20251215102626-e0db08df7383
github.com/fatih/color v1.18.0 github.com/fatih/color v1.18.0
github.com/mattn/go-runewidth v0.0.16
github.com/spf13/cobra v1.9.1 github.com/spf13/cobra v1.9.1
github.com/spf13/viper v1.19.0
github.com/tidwall/pretty v1.2.1 github.com/tidwall/pretty v1.2.1
golang.org/x/term v0.38.0
) )
require ( require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymanbagabas/go-udiff v0.3.1 // indirect github.com/aymanbagabas/go-udiff v0.3.1 // indirect
github.com/charmbracelet/colorprofile v0.3.2 // indirect github.com/charmbracelet/colorprofile v0.3.2 // indirect
github.com/charmbracelet/lipgloss v1.1.0 // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/exp/golden v0.0.0-20240806155701-69247e0abc2a // indirect github.com/charmbracelet/x/exp/golden v0.0.0-20240806155701-69247e0abc2a // indirect
github.com/charmbracelet/x/exp/teatest v0.0.0-20251215102626-e0db08df7383 // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect
@@ -28,7 +31,6 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect
@@ -41,7 +43,6 @@ require (
github.com/spf13/afero v1.11.0 // indirect github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect github.com/spf13/pflag v1.0.6 // indirect
github.com/spf13/viper v1.19.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect
@@ -49,7 +50,6 @@ require (
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/sync v0.16.0 // indirect golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.39.0 // indirect golang.org/x/sys v0.39.0 // indirect
golang.org/x/term v0.38.0 // indirect
golang.org/x/text v0.28.0 // indirect golang.org/x/text v0.28.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

29
go.sum
View File

@@ -4,14 +4,10 @@ github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3v
github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E= github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E=
github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU= github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU=
github.com/charmbracelet/bubbletea v1.3.6/go.mod h1:oQD9VCRQFF8KplacJLo28/jofOI2ToOfGYeFgBBxHOc= github.com/charmbracelet/bubbletea v1.3.6/go.mod h1:oQD9VCRQFF8KplacJLo28/jofOI2ToOfGYeFgBBxHOc=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/colorprofile v0.3.2 h1:9J27WdztfJQVAQKX2WOlSSRB+5gaKqqITmrvb1uTIiI= github.com/charmbracelet/colorprofile v0.3.2 h1:9J27WdztfJQVAQKX2WOlSSRB+5gaKqqITmrvb1uTIiI=
github.com/charmbracelet/colorprofile v0.3.2/go.mod h1:mTD5XzNeWHj8oqHb+S1bssQb7vIHbepiebQ2kPKVKbI= github.com/charmbracelet/colorprofile v0.3.2/go.mod h1:mTD5XzNeWHj8oqHb+S1bssQb7vIHbepiebQ2kPKVKbI=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0=
github.com/charmbracelet/x/ansi v0.9.3/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
@@ -25,16 +21,26 @@ github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNE
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
@@ -59,9 +65,13 @@ github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
@@ -87,6 +97,7 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
@@ -98,28 +109,22 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -37,6 +37,19 @@ static const char* get_username(int uid) {
return pw->pw_name; return pw->pw_name;
} }
// get current working directory for a process
static int get_proc_cwd(int pid, char *path, int pathlen) {
struct proc_vnodepathinfo vpi;
int ret = proc_pidinfo(pid, PROC_PIDVNODEPATHINFO, 0, &vpi, sizeof(vpi));
if (ret <= 0) {
path[0] = '\0';
return -1;
}
strncpy(path, vpi.pvi_cdir.vip_path, pathlen - 1);
path[pathlen - 1] = '\0';
return 0;
}
// socket info extraction - handles the union properly in C // socket info extraction - handles the union properly in C
typedef struct { typedef struct {
int family; int family;
@@ -164,6 +177,7 @@ func listAllPids() ([]int, error) {
func getConnectionsForPid(pid int) ([]Connection, error) { func getConnectionsForPid(pid int) ([]Connection, error) {
procName := getProcessName(pid) procName := getProcessName(pid)
cwd := getProcessCwd(pid)
uid := int(C.get_proc_uid(C.int(pid))) uid := int(C.get_proc_uid(C.int(pid)))
user := "" user := ""
if uid >= 0 { if uid >= 0 {
@@ -198,7 +212,7 @@ func getConnectionsForPid(pid int) ([]Connection, error) {
continue continue
} }
conn, ok := getSocketInfo(pid, int(fdInfo.proc_fd), procName, uid, user) conn, ok := getSocketInfo(pid, int(fdInfo.proc_fd), procName, cwd, uid, user)
if ok { if ok {
connections = append(connections, conn) connections = append(connections, conn)
} }
@@ -207,7 +221,7 @@ func getConnectionsForPid(pid int) ([]Connection, error) {
return connections, nil return connections, nil
} }
func getSocketInfo(pid, fd int, procName string, uid int, user string) (Connection, bool) { func getSocketInfo(pid, fd int, procName, cwd string, uid int, user string) (Connection, bool) {
var info C.socket_info_t var info C.socket_info_t
ret := C.get_socket_info(C.int(pid), C.int(fd), &info) ret := C.get_socket_info(C.int(pid), C.int(fd), &info)
@@ -276,6 +290,7 @@ func getSocketInfo(pid, fd int, procName string, uid int, user string) (Connecti
Rport: int(info.rport), Rport: int(info.rport),
PID: pid, PID: pid,
Process: procName, Process: procName,
Cwd: cwd,
UID: uid, UID: uid,
User: user, User: user,
Interface: guessNetworkInterface(laddr), Interface: guessNetworkInterface(laddr),
@@ -293,6 +308,15 @@ func getProcessName(pid int) string {
return C.GoString(&name[0]) return C.GoString(&name[0])
} }
func getProcessCwd(pid int) string {
var path [1024]C.char
ret := C.get_proc_cwd(C.int(pid), &path[0], 1024)
if ret != 0 {
return ""
}
return C.GoString(&path[0])
}
func ipv4ToString(addr uint32) string { func ipv4ToString(addr uint32) string {
ip := make(net.IP, 4) ip := make(net.IP, 4)
ip[0] = byte(addr) ip[0] = byte(addr)

View File

@@ -11,21 +11,78 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
"github.com/karol-broda/snitch/internal/errutil"
) )
// set SNITCH_DEBUG_TIMING=1 to enable timing diagnostics
var debugTiming = os.Getenv("SNITCH_DEBUG_TIMING") != ""
func logTiming(label string, start time.Time, extra ...string) {
if !debugTiming {
return
}
elapsed := time.Since(start)
if len(extra) > 0 {
fmt.Fprintf(os.Stderr, "[timing] %s: %v (%s)\n", label, elapsed, extra[0])
} else {
fmt.Fprintf(os.Stderr, "[timing] %s: %v\n", label, elapsed)
}
}
// userCache caches uid to username mappings to avoid repeated lookups
var userCache = struct {
sync.RWMutex
m map[int]string
}{m: make(map[int]string)}
func lookupUsername(uid int) string {
userCache.RLock()
if username, exists := userCache.m[uid]; exists {
userCache.RUnlock()
return username
}
userCache.RUnlock()
start := time.Now()
username := strconv.Itoa(uid)
u, err := user.LookupId(strconv.Itoa(uid))
if err == nil && u != nil {
username = u.Username
}
elapsed := time.Since(start)
if debugTiming && elapsed > 10*time.Millisecond {
fmt.Fprintf(os.Stderr, "[timing] user.LookupId(%d) slow: %v\n", uid, elapsed)
}
userCache.Lock()
userCache.m[uid] = username
userCache.Unlock()
return username
}
// DefaultCollector implements the Collector interface using /proc filesystem // DefaultCollector implements the Collector interface using /proc filesystem
type DefaultCollector struct{} type DefaultCollector struct{}
// GetConnections fetches all network connections by parsing /proc files // GetConnections fetches all network connections by parsing /proc files
func (dc *DefaultCollector) GetConnections() ([]Connection, error) { func (dc *DefaultCollector) GetConnections() ([]Connection, error) {
totalStart := time.Now()
defer func() { logTiming("GetConnections total", totalStart) }()
inodeStart := time.Now()
inodeMap, err := buildInodeToProcessMap() inodeMap, err := buildInodeToProcessMap()
logTiming("buildInodeToProcessMap", inodeStart, fmt.Sprintf("%d inodes", len(inodeMap)))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to build inode map: %w", err) return nil, fmt.Errorf("failed to build inode map: %w", err)
} }
var connections []Connection var connections []Connection
parseStart := time.Now()
tcpConns, err := parseProcNet("/proc/net/tcp", "tcp", 4, inodeMap) tcpConns, err := parseProcNet("/proc/net/tcp", "tcp", 4, inodeMap)
if err == nil { if err == nil {
connections = append(connections, tcpConns...) connections = append(connections, tcpConns...)
@@ -45,6 +102,7 @@ func (dc *DefaultCollector) GetConnections() ([]Connection, error) {
if err == nil { if err == nil {
connections = append(connections, udpConns6...) connections = append(connections, udpConns6...)
} }
logTiming("parseProcNet (all)", parseStart, fmt.Sprintf("%d connections", len(connections)))
return connections, nil return connections, nil
} }
@@ -67,102 +125,175 @@ func GetAllConnections() ([]Connection, error) {
type processInfo struct { type processInfo struct {
pid int pid int
command string command string
cmdline string
cwd string
uid int uid int
user string user string
} }
func buildInodeToProcessMap() (map[int64]*processInfo, error) { type inodeEntry struct {
inodeMap := make(map[int64]*processInfo) inode int64
info *processInfo
}
func buildInodeToProcessMap() (map[int64]*processInfo, error) {
readDirStart := time.Now()
procDir, err := os.Open("/proc") procDir, err := os.Open("/proc")
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer procDir.Close() defer errutil.Close(procDir)
entries, err := procDir.Readdir(-1) entries, err := procDir.Readdir(-1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// collect pids first
pids := make([]int, 0, len(entries))
for _, entry := range entries { for _, entry := range entries {
if !entry.IsDir() { if !entry.IsDir() {
continue continue
} }
pid, err := strconv.Atoi(entry.Name())
pidStr := entry.Name()
pid, err := strconv.Atoi(pidStr)
if err != nil { if err != nil {
continue continue
} }
pids = append(pids, pid)
}
logTiming(" readdir /proc", readDirStart, fmt.Sprintf("%d pids", len(pids)))
procInfo, err := getProcessInfo(pid) // process pids in parallel with limited concurrency
if err != nil { scanStart := time.Now()
continue const numWorkers = 8
} pidChan := make(chan int, len(pids))
resultChan := make(chan []inodeEntry, len(pids))
fdDir := filepath.Join("/proc", pidStr, "fd") var totalFDs atomic.Int64
fdEntries, err := os.ReadDir(fdDir) var wg sync.WaitGroup
if err != nil { for i := 0; i < numWorkers; i++ {
continue wg.Add(1)
} go func() {
defer wg.Done()
for _, fdEntry := range fdEntries { for pid := range pidChan {
fdPath := filepath.Join(fdDir, fdEntry.Name()) entries := scanProcessSockets(pid)
link, err := os.Readlink(fdPath) if len(entries) > 0 {
if err != nil { totalFDs.Add(int64(len(entries)))
continue resultChan <- entries
}
if strings.HasPrefix(link, "socket:[") && strings.HasSuffix(link, "]") {
inodeStr := link[8 : len(link)-1]
inode, err := strconv.ParseInt(inodeStr, 10, 64)
if err != nil {
continue
} }
inodeMap[inode] = procInfo
} }
}()
}
for _, pid := range pids {
pidChan <- pid
}
close(pidChan)
go func() {
wg.Wait()
close(resultChan)
}()
inodeMap := make(map[int64]*processInfo)
for entries := range resultChan {
for _, e := range entries {
inodeMap[e.inode] = e.info
} }
} }
logTiming(" scan all processes", scanStart, fmt.Sprintf("%d socket fds scanned", totalFDs.Load()))
return inodeMap, nil return inodeMap, nil
} }
func scanProcessSockets(pid int) []inodeEntry {
start := time.Now()
procInfo, err := getProcessInfo(pid)
if err != nil {
return nil
}
pidStr := strconv.Itoa(pid)
fdDir := filepath.Join("/proc", pidStr, "fd")
fdEntries, err := os.ReadDir(fdDir)
if err != nil {
return nil
}
var results []inodeEntry
for _, fdEntry := range fdEntries {
fdPath := filepath.Join(fdDir, fdEntry.Name())
link, err := os.Readlink(fdPath)
if err != nil {
continue
}
if strings.HasPrefix(link, "socket:[") && strings.HasSuffix(link, "]") {
inodeStr := link[8 : len(link)-1]
inode, err := strconv.ParseInt(inodeStr, 10, 64)
if err != nil {
continue
}
results = append(results, inodeEntry{inode: inode, info: procInfo})
}
}
elapsed := time.Since(start)
if debugTiming && elapsed > 20*time.Millisecond {
fmt.Fprintf(os.Stderr, "[timing] slow process scan: pid=%d (%s) fds=%d time=%v\n",
pid, procInfo.command, len(fdEntries), elapsed)
}
return results
}
func getProcessInfo(pid int) (*processInfo, error) { func getProcessInfo(pid int) (*processInfo, error) {
info := &processInfo{pid: pid} info := &processInfo{pid: pid}
pidStr := strconv.Itoa(pid)
commPath := filepath.Join("/proc", strconv.Itoa(pid), "comm") commPath := filepath.Join("/proc", pidStr, "comm")
commData, err := os.ReadFile(commPath) commData, err := os.ReadFile(commPath)
if err == nil && len(commData) > 0 { if err == nil && len(commData) > 0 {
info.command = strings.TrimSpace(string(commData)) info.command = strings.TrimSpace(string(commData))
} }
if info.command == "" { cmdlinePath := filepath.Join("/proc", pidStr, "cmdline")
cmdlinePath := filepath.Join("/proc", strconv.Itoa(pid), "cmdline") cmdlineData, err := os.ReadFile(cmdlinePath)
cmdlineData, err := os.ReadFile(cmdlinePath) if err == nil && len(cmdlineData) > 0 {
if err != nil { parts := bytes.Split(cmdlineData, []byte{0})
return nil, err var args []string
} for _, p := range parts {
if len(p) > 0 {
if len(cmdlineData) > 0 { args = append(args, string(p))
parts := bytes.Split(cmdlineData, []byte{0})
if len(parts) > 0 && len(parts[0]) > 0 {
fullPath := string(parts[0])
baseName := filepath.Base(fullPath)
if strings.Contains(baseName, " ") {
baseName = strings.Fields(baseName)[0]
}
info.command = baseName
} }
} }
info.cmdline = strings.Join(args, " ")
if info.command == "" && len(parts) > 0 && len(parts[0]) > 0 {
fullPath := string(parts[0])
baseName := filepath.Base(fullPath)
if strings.Contains(baseName, " ") {
baseName = strings.Fields(baseName)[0]
}
info.command = baseName
}
} else if info.command == "" {
return nil, err
} }
statusPath := filepath.Join("/proc", strconv.Itoa(pid), "status") cwdPath := filepath.Join("/proc", pidStr, "cwd")
cwdLink, err := os.Readlink(cwdPath)
if err == nil {
info.cwd = cwdLink
}
statusPath := filepath.Join("/proc", pidStr, "status")
statusFile, err := os.Open(statusPath) statusFile, err := os.Open(statusPath)
if err != nil { if err != nil {
return info, nil return info, nil
} }
defer statusFile.Close() defer errutil.Close(statusFile)
scanner := bufio.NewScanner(statusFile) scanner := bufio.NewScanner(statusFile)
for scanner.Scan() { for scanner.Scan() {
@@ -173,12 +304,7 @@ func getProcessInfo(pid int) (*processInfo, error) {
uid, err := strconv.Atoi(fields[1]) uid, err := strconv.Atoi(fields[1])
if err == nil { if err == nil {
info.uid = uid info.uid = uid
u, err := user.LookupId(strconv.Itoa(uid)) info.user = lookupUsername(uid)
if err == nil {
info.user = u.Username
} else {
info.user = strconv.Itoa(uid)
}
} }
} }
break break
@@ -193,7 +319,7 @@ func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*process
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close() defer errutil.Close(file)
var connections []Connection var connections []Connection
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
@@ -248,6 +374,8 @@ func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*process
if procInfo, exists := inodeMap[inode]; exists { if procInfo, exists := inodeMap[inode]; exists {
conn.PID = procInfo.pid conn.PID = procInfo.pid
conn.Process = procInfo.command conn.Process = procInfo.command
conn.Cmdline = procInfo.cmdline
conn.Cwd = procInfo.cwd
conn.UID = procInfo.uid conn.UID = procInfo.uid
conn.User = procInfo.user conn.User = procInfo.user
} }
@@ -362,7 +490,7 @@ func GetUnixSockets() ([]Connection, error) {
if err != nil { if err != nil {
return connections, nil return connections, nil
} }
defer file.Close() defer errutil.Close(file)
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
scanner.Scan() scanner.Scan()

View File

@@ -1,7 +1,10 @@
//go:build linux
package collector package collector
import ( import (
"testing" "testing"
"time"
) )
func TestGetConnections(t *testing.T) { func TestGetConnections(t *testing.T) {
@@ -13,4 +16,158 @@ func TestGetConnections(t *testing.T) {
// connections are dynamic, so just verify function succeeded // connections are dynamic, so just verify function succeeded
t.Logf("Successfully got %d connections", len(conns)) t.Logf("Successfully got %d connections", len(conns))
}
func TestGetConnectionsPerformance(t *testing.T) {
// measures performance to catch regressions
// run with: go test -v -run TestGetConnectionsPerformance
const maxDuration = 500 * time.Millisecond
const iterations = 5
// warm up caches first
_, err := GetConnections()
if err != nil {
t.Fatalf("warmup failed: %v", err)
}
var total time.Duration
var maxSeen time.Duration
for i := 0; i < iterations; i++ {
start := time.Now()
conns, err := GetConnections()
elapsed := time.Since(start)
if err != nil {
t.Fatalf("iteration %d failed: %v", i, err)
}
total += elapsed
if elapsed > maxSeen {
maxSeen = elapsed
}
t.Logf("iteration %d: %v (%d connections)", i+1, elapsed, len(conns))
}
avg := total / time.Duration(iterations)
t.Logf("average: %v, max: %v", avg, maxSeen)
if maxSeen > maxDuration {
t.Errorf("slowest iteration took %v, expected < %v", maxSeen, maxDuration)
}
}
func TestGetConnectionsColdCache(t *testing.T) {
// tests performance with cold user cache
// this simulates first run or after cache invalidation
const maxDuration = 2 * time.Second
clearUserCache()
start := time.Now()
conns, err := GetConnections()
elapsed := time.Since(start)
if err != nil {
t.Fatalf("GetConnections() failed: %v", err)
}
t.Logf("cold cache: %v (%d connections, %d cached users after)",
elapsed, len(conns), userCacheSize())
if elapsed > maxDuration {
t.Errorf("cold cache took %v, expected < %v", elapsed, maxDuration)
}
}
func BenchmarkGetConnections(b *testing.B) {
// warm cache benchmark - measures typical runtime
// run with: go test -bench=BenchmarkGetConnections -benchtime=5s
// warm up
_, _ = GetConnections()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = GetConnections()
}
}
func BenchmarkGetConnectionsColdCache(b *testing.B) {
// cold cache benchmark - measures worst-case with cache cleared each iteration
// run with: go test -bench=BenchmarkGetConnectionsColdCache -benchtime=10s
b.ResetTimer()
for i := 0; i < b.N; i++ {
clearUserCache()
_, _ = GetConnections()
}
}
func BenchmarkBuildInodeMap(b *testing.B) {
// benchmarks just the inode map building (most expensive part)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = buildInodeToProcessMap()
}
}
func TestConnectionHasCmdlineAndCwd(t *testing.T) {
conns, err := GetConnections()
if err != nil {
t.Fatalf("GetConnections() returned an error: %v", err)
}
if len(conns) == 0 {
t.Skip("no connections to test")
}
// find a connection with a PID (owned by some process)
var connWithProcess *Connection
for i := range conns {
if conns[i].PID > 0 {
connWithProcess = &conns[i]
break
}
}
if connWithProcess == nil {
t.Skip("no connections with associated process found")
}
t.Logf("testing connection: pid=%d process=%s", connWithProcess.PID, connWithProcess.Process)
// cmdline and cwd should be populated for connections with PIDs
// note: they might be empty if we don't have permission to read them
if connWithProcess.Cmdline != "" {
t.Logf("cmdline: %s", connWithProcess.Cmdline)
} else {
t.Logf("cmdline is empty (might be permission issue)")
}
if connWithProcess.Cwd != "" {
t.Logf("cwd: %s", connWithProcess.Cwd)
} else {
t.Logf("cwd is empty (might be permission issue)")
}
}
func TestGetProcessInfoPopulatesCmdlineAndCwd(t *testing.T) {
// test that getProcessInfo correctly populates cmdline and cwd for our own process
info, err := getProcessInfo(1) // init process (usually has cwd of /)
if err != nil {
t.Logf("could not get process info for pid 1: %v", err)
t.Skip("skipping - may not have permission")
}
t.Logf("pid 1 info: command=%s cmdline=%s cwd=%s", info.command, info.cmdline, info.cwd)
// at minimum, we should have a command name
if info.command == "" && info.cmdline == "" {
t.Error("expected either command or cmdline to be populated")
}
} }

View File

@@ -0,0 +1,18 @@
//go:build linux
package collector
// clearUserCache clears the user lookup cache for testing
func clearUserCache() {
userCache.Lock()
userCache.m = make(map[int]string)
userCache.Unlock()
}
// userCacheSize returns the number of cached user entries
func userCacheSize() int {
userCache.RLock()
defer userCache.RUnlock()
return len(userCache.m)
}

View File

@@ -128,3 +128,75 @@ func TestSortByTimestamp(t *testing.T) {
} }
} }
func TestSortByRemoteAddr(t *testing.T) {
conns := []Connection{
{Raddr: "192.168.1.100", Rport: 443},
{Raddr: "10.0.0.1", Rport: 80},
{Raddr: "172.16.0.50", Rport: 8080},
}
t.Run("sort by raddr ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRaddr, Direction: SortAsc})
if c[0].Raddr != "10.0.0.1" {
t.Errorf("expected '10.0.0.1' first, got '%s'", c[0].Raddr)
}
if c[1].Raddr != "172.16.0.50" {
t.Errorf("expected '172.16.0.50' second, got '%s'", c[1].Raddr)
}
if c[2].Raddr != "192.168.1.100" {
t.Errorf("expected '192.168.1.100' last, got '%s'", c[2].Raddr)
}
})
t.Run("sort by raddr descending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRaddr, Direction: SortDesc})
if c[0].Raddr != "192.168.1.100" {
t.Errorf("expected '192.168.1.100' first, got '%s'", c[0].Raddr)
}
})
}
func TestSortByRemotePort(t *testing.T) {
conns := []Connection{
{Raddr: "192.168.1.1", Rport: 443},
{Raddr: "192.168.1.2", Rport: 80},
{Raddr: "192.168.1.3", Rport: 8080},
}
t.Run("sort by rport ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRport, Direction: SortAsc})
if c[0].Rport != 80 {
t.Errorf("expected port 80 first, got %d", c[0].Rport)
}
if c[1].Rport != 443 {
t.Errorf("expected port 443 second, got %d", c[1].Rport)
}
if c[2].Rport != 8080 {
t.Errorf("expected port 8080 last, got %d", c[2].Rport)
}
})
t.Run("sort by rport descending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRport, Direction: SortDesc})
if c[0].Rport != 8080 {
t.Errorf("expected port 8080 first, got %d", c[0].Rport)
}
})
}

View File

@@ -6,6 +6,8 @@ type Connection struct {
TS time.Time `json:"ts"` TS time.Time `json:"ts"`
PID int `json:"pid"` PID int `json:"pid"`
Process string `json:"process"` Process string `json:"process"`
Cmdline string `json:"cmdline,omitempty"`
Cwd string `json:"cwd,omitempty"`
User string `json:"user"` User string `json:"user"`
UID int `json:"uid"` UID int `json:"uid"`
Proto string `json:"proto"` Proto string `json:"proto"`

View File

@@ -5,6 +5,8 @@ import (
"testing" "testing"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/karol-broda/snitch/internal/errutil"
) )
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
@@ -29,8 +31,8 @@ func TestInit(t *testing.T) {
origTerm := os.Getenv("TERM") origTerm := os.Getenv("TERM")
// Set test env vars // Set test env vars
os.Setenv("NO_COLOR", tc.noColor) errutil.Setenv("NO_COLOR", tc.noColor)
os.Setenv("TERM", tc.term) errutil.Setenv("TERM", tc.term)
Init(tc.mode) Init(tc.mode)
@@ -39,8 +41,8 @@ func TestInit(t *testing.T) {
} }
// Restore original env vars // Restore original env vars
os.Setenv("NO_COLOR", origNoColor) errutil.Setenv("NO_COLOR", origNoColor)
os.Setenv("TERM", origTerm) errutil.Setenv("TERM", origTerm)
}) })
} }
} }

View File

@@ -4,14 +4,22 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/karol-broda/snitch/internal/theme"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
// Config represents the application configuration // Config represents the application configuration
type Config struct { type Config struct {
Defaults DefaultConfig `mapstructure:"defaults"` Defaults DefaultConfig `mapstructure:"defaults"`
TUI TUIConfig `mapstructure:"tui"`
}
// TUIConfig contains TUI-specific configuration
type TUIConfig struct {
RememberState bool `mapstructure:"remember_state"`
} }
// DefaultConfig contains default values for CLI options // DefaultConfig contains default values for CLI options
@@ -23,6 +31,7 @@ type DefaultConfig struct {
Units string `mapstructure:"units"` Units string `mapstructure:"units"`
Color string `mapstructure:"color"` Color string `mapstructure:"color"`
Resolve bool `mapstructure:"resolve"` Resolve bool `mapstructure:"resolve"`
DNSCache bool `mapstructure:"dns_cache"`
IPv4 bool `mapstructure:"ipv4"` IPv4 bool `mapstructure:"ipv4"`
IPv6 bool `mapstructure:"ipv6"` IPv6 bool `mapstructure:"ipv6"`
NoHeaders bool `mapstructure:"no_headers"` NoHeaders bool `mapstructure:"no_headers"`
@@ -55,6 +64,7 @@ func Load() (*Config, error) {
// environment variable bindings for readme-documented variables // environment variable bindings for readme-documented variables
_ = v.BindEnv("config", "SNITCH_CONFIG") _ = v.BindEnv("config", "SNITCH_CONFIG")
_ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE") _ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE")
_ = v.BindEnv("defaults.dns_cache", "SNITCH_DNS_CACHE")
_ = v.BindEnv("defaults.theme", "SNITCH_THEME") _ = v.BindEnv("defaults.theme", "SNITCH_THEME")
_ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR") _ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR")
@@ -88,19 +98,22 @@ func Load() (*Config, error) {
} }
func setDefaults(v *viper.Viper) { func setDefaults(v *viper.Viper) {
// Set default values matching the README specification
v.SetDefault("defaults.interval", "1s") v.SetDefault("defaults.interval", "1s")
v.SetDefault("defaults.numeric", false) v.SetDefault("defaults.numeric", false)
v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"}) v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"})
v.SetDefault("defaults.theme", "auto") v.SetDefault("defaults.theme", "ansi")
v.SetDefault("defaults.units", "auto") v.SetDefault("defaults.units", "auto")
v.SetDefault("defaults.color", "auto") v.SetDefault("defaults.color", "auto")
v.SetDefault("defaults.resolve", true) v.SetDefault("defaults.resolve", true)
v.SetDefault("defaults.dns_cache", true)
v.SetDefault("defaults.ipv4", false) v.SetDefault("defaults.ipv4", false)
v.SetDefault("defaults.ipv6", false) v.SetDefault("defaults.ipv6", false)
v.SetDefault("defaults.no_headers", false) v.SetDefault("defaults.no_headers", false)
v.SetDefault("defaults.output_format", "table") v.SetDefault("defaults.output_format", "table")
v.SetDefault("defaults.sort_by", "") v.SetDefault("defaults.sort_by", "")
// tui settings
v.SetDefault("tui.remember_state", false)
} }
func handleSpecialEnvVars(v *viper.Viper) { func handleSpecialEnvVars(v *viper.Viper) {
@@ -114,6 +127,11 @@ func handleSpecialEnvVars(v *viper.Viper) {
v.Set("defaults.resolve", false) v.Set("defaults.resolve", false)
v.Set("defaults.numeric", true) v.Set("defaults.numeric", true)
} }
// Handle SNITCH_DNS_CACHE - if set to "0", disable dns caching
if os.Getenv("SNITCH_DNS_CACHE") == "0" {
v.Set("defaults.dns_cache", false)
}
} }
// Get returns the global configuration, loading it if necessary // Get returns the global configuration, loading it if necessary
@@ -121,22 +139,25 @@ func Get() *Config {
if globalConfig == nil { if globalConfig == nil {
config, err := Load() config, err := Load()
if err != nil { if err != nil {
// Return default config on error
return &Config{ return &Config{
Defaults: DefaultConfig{ Defaults: DefaultConfig{
Interval: "1s", Interval: "1s",
Numeric: false, Numeric: false,
Fields: []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"}, Fields: []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"},
Theme: "auto", Theme: "ansi",
Units: "auto", Units: "auto",
Color: "auto", Color: "auto",
Resolve: true, Resolve: true,
DNSCache: true,
IPv4: false, IPv4: false,
IPv6: false, IPv6: false,
NoHeaders: false, NoHeaders: false,
OutputFormat: "table", OutputFormat: "table",
SortBy: "", SortBy: "",
}, },
TUI: TUIConfig{
RememberState: false,
},
} }
} }
return config return config
@@ -154,7 +175,9 @@ func (c *Config) GetInterval() time.Duration {
// CreateExampleConfig creates an example configuration file // CreateExampleConfig creates an example configuration file
func CreateExampleConfig(path string) error { func CreateExampleConfig(path string) error {
exampleConfig := `# snitch configuration file themeList := strings.Join(theme.ListThemes(), ", ")
exampleConfig := fmt.Sprintf(`# snitch configuration file
# See https://github.com/you/snitch for full documentation # See https://github.com/you/snitch for full documentation
[defaults] [defaults]
@@ -167,8 +190,9 @@ numeric = false
# Default fields to display (comma-separated list) # Default fields to display (comma-separated list)
fields = ["pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"] fields = ["pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"]
# Default theme for TUI (dark, light, mono, auto) # Default theme for TUI (ansi inherits terminal colors)
theme = "auto" # Available: %s
theme = "%s"
# Default units for byte display (auto, si, iec) # Default units for byte display (auto, si, iec)
units = "auto" units = "auto"
@@ -187,17 +211,22 @@ ipv6 = false
no_headers = false no_headers = false
output_format = "table" output_format = "table"
sort_by = "" sort_by = ""
`
[tui]
# remember view options (filters, sort, resolution) between sessions
# state is saved to $XDG_STATE_HOME/snitch/tui.json
remember_state = false
`, themeList, theme.DefaultTheme)
// Ensure directory exists // Ensure directory exists
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err) return fmt.Errorf("failed to create config directory: %w", err)
} }
// Write config file // Write config file
if err := os.WriteFile(path, []byte(exampleConfig), 0644); err != nil { if err := os.WriteFile(path, []byte(exampleConfig), 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err) return fmt.Errorf("failed to write config file: %w", err)
} }
return nil return nil
} }

View File

@@ -0,0 +1,65 @@
package errutil
import (
"io"
"os"
"github.com/fatih/color"
)
func Ignore[T any](val T, _ error) T {
return val
}
func IgnoreErr(_ error) {}
func Close(c io.Closer) {
if c != nil {
_ = c.Close()
}
}
// color.Color wrappers - these discard the (int, error) return values
func Print(c *color.Color, a ...any) {
_, _ = c.Print(a...)
}
func Println(c *color.Color, a ...any) {
_, _ = c.Println(a...)
}
func Printf(c *color.Color, format string, a ...any) {
_, _ = c.Printf(format, a...)
}
func Fprintf(c *color.Color, w io.Writer, format string, a ...any) {
_, _ = c.Fprintf(w, format, a...)
}
// os function wrappers for test cleanup where errors are non-critical
func Setenv(key, value string) {
_ = os.Setenv(key, value)
}
func Unsetenv(key string) {
_ = os.Unsetenv(key)
}
func Remove(name string) {
_ = os.Remove(name)
}
func RemoveAll(path string) {
_ = os.RemoveAll(path)
}
// Flush calls Flush on a tabwriter and discards the error
type Flusher interface {
Flush() error
}
func Flush(f Flusher) {
_ = f.Flush()
}

View File

@@ -2,17 +2,22 @@ package resolver
import ( import (
"context" "context"
"fmt"
"net" "net"
"os"
"strconv" "strconv"
"sync" "sync"
"time" "time"
) )
var debugTiming = os.Getenv("SNITCH_DEBUG_TIMING") != ""
// Resolver handles DNS and service name resolution with caching and timeouts // Resolver handles DNS and service name resolution with caching and timeouts
type Resolver struct { type Resolver struct {
timeout time.Duration timeout time.Duration
cache map[string]string cache map[string]string
mutex sync.RWMutex mutex sync.RWMutex
noCache bool
} }
// New creates a new resolver with the specified timeout // New creates a new resolver with the specified timeout
@@ -20,45 +25,60 @@ func New(timeout time.Duration) *Resolver {
return &Resolver{ return &Resolver{
timeout: timeout, timeout: timeout,
cache: make(map[string]string), cache: make(map[string]string),
noCache: false,
} }
} }
// SetNoCache disables caching - each lookup will hit DNS directly
func (r *Resolver) SetNoCache(noCache bool) {
r.noCache = noCache
}
// ResolveAddr resolves an IP address to a hostname, with caching // ResolveAddr resolves an IP address to a hostname, with caching
func (r *Resolver) ResolveAddr(addr string) string { func (r *Resolver) ResolveAddr(addr string) string {
// Check cache first // check cache first (unless caching is disabled)
r.mutex.RLock() if !r.noCache {
if cached, exists := r.cache[addr]; exists { r.mutex.RLock()
if cached, exists := r.cache[addr]; exists {
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock() r.mutex.RUnlock()
return cached
} }
r.mutex.RUnlock()
// Parse IP to validate it // parse ip to validate it
ip := net.ParseIP(addr) ip := net.ParseIP(addr)
if ip == nil { if ip == nil {
// Not a valid IP, return as-is
return addr return addr
} }
// Perform resolution with timeout // perform resolution with timeout
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), r.timeout) ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel() defer cancel()
names, err := net.DefaultResolver.LookupAddr(ctx, addr) names, err := net.DefaultResolver.LookupAddr(ctx, addr)
resolved := addr // fallback to original address resolved := addr
if err == nil && len(names) > 0 { if err == nil && len(names) > 0 {
resolved = names[0] resolved = names[0]
// Remove trailing dot if present // remove trailing dot if present
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' { if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
resolved = resolved[:len(resolved)-1] resolved = resolved[:len(resolved)-1]
} }
} }
// Cache the result elapsed := time.Since(start)
r.mutex.Lock() if debugTiming && elapsed > 50*time.Millisecond {
r.cache[addr] = resolved fmt.Fprintf(os.Stderr, "[timing] slow DNS lookup: %s -> %s (%v)\n", addr, resolved, elapsed)
r.mutex.Unlock() }
// cache the result (unless caching is disabled)
if !r.noCache {
r.mutex.Lock()
r.cache[addr] = resolved
r.mutex.Unlock()
}
return resolved return resolved
} }
@@ -71,15 +91,17 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
cacheKey := strconv.Itoa(port) + "/" + proto cacheKey := strconv.Itoa(port) + "/" + proto
// Check cache first // check cache first (unless caching is disabled)
r.mutex.RLock() if !r.noCache {
if cached, exists := r.cache[cacheKey]; exists { r.mutex.RLock()
if cached, exists := r.cache[cacheKey]; exists {
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock() r.mutex.RUnlock()
return cached
} }
r.mutex.RUnlock()
// Perform resolution with timeout // perform resolution with timeout
ctx, cancel := context.WithTimeout(context.Background(), r.timeout) ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel() defer cancel()
@@ -87,16 +109,18 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
resolved := strconv.Itoa(port) // fallback to port number resolved := strconv.Itoa(port) // fallback to port number
if err == nil && service != 0 { if err == nil && service != 0 {
// Try to get service name // try to get service name
if serviceName := getServiceName(port, proto); serviceName != "" { if serviceName := getServiceName(port, proto); serviceName != "" {
resolved = serviceName resolved = serviceName
} }
} }
// Cache the result // cache the result (unless caching is disabled)
r.mutex.Lock() if !r.noCache {
r.cache[cacheKey] = resolved r.mutex.Lock()
r.mutex.Unlock() r.cache[cacheKey] = resolved
r.mutex.Unlock()
}
return resolved return resolved
} }
@@ -159,22 +183,38 @@ func getServiceName(port int, proto string) string {
return "" return ""
} }
// Global resolver instance // global resolver instance
var globalResolver *Resolver var globalResolver *Resolver
// SetGlobalResolver sets the global resolver instance // ResolverOptions configures the global resolver
func SetGlobalResolver(timeout time.Duration) { type ResolverOptions struct {
Timeout time.Duration
NoCache bool
}
// SetGlobalResolver sets the global resolver instance with options
func SetGlobalResolver(opts ResolverOptions) {
timeout := opts.Timeout
if timeout == 0 {
timeout = 200 * time.Millisecond
}
globalResolver = New(timeout) globalResolver = New(timeout)
globalResolver.SetNoCache(opts.NoCache)
} }
// GetGlobalResolver returns the global resolver instance // GetGlobalResolver returns the global resolver instance
func GetGlobalResolver() *Resolver { func GetGlobalResolver() *Resolver {
if globalResolver == nil { if globalResolver == nil {
globalResolver = New(200 * time.Millisecond) // Default timeout globalResolver = New(200 * time.Millisecond)
} }
return globalResolver return globalResolver
} }
// SetNoCache configures whether the global resolver bypasses cache
func SetNoCache(noCache bool) {
GetGlobalResolver().SetNoCache(noCache)
}
// ResolveAddr is a convenience function using the global resolver // ResolveAddr is a convenience function using the global resolver
func ResolveAddr(addr string) string { func ResolveAddr(addr string) string {
return GetGlobalResolver().ResolveAddr(addr) return GetGlobalResolver().ResolveAddr(addr)
@@ -189,3 +229,48 @@ func ResolvePort(port int, proto string) string {
func ResolveAddrPort(addr string, port int, proto string) (string, string) { func ResolveAddrPort(addr string, port int, proto string) (string, string) {
return GetGlobalResolver().ResolveAddrPort(addr, port, proto) return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
} }
// ResolveAddrsParallel resolves multiple addresses concurrently and caches results.
// This should be called before rendering to pre-warm the cache.
func (r *Resolver) ResolveAddrsParallel(addrs []string) {
// dedupe and filter addresses that need resolution
unique := make(map[string]struct{})
for _, addr := range addrs {
if addr == "" || addr == "*" {
continue
}
// skip if already cached
r.mutex.RLock()
_, exists := r.cache[addr]
r.mutex.RUnlock()
if exists {
continue
}
unique[addr] = struct{}{}
}
if len(unique) == 0 {
return
}
var wg sync.WaitGroup
// limit concurrency to avoid overwhelming dns
sem := make(chan struct{}, 32)
for addr := range unique {
wg.Add(1)
go func(a string) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
r.ResolveAddr(a)
}(addr)
}
wg.Wait()
}
// ResolveAddrsParallel is a convenience function using the global resolver
func ResolveAddrsParallel(addrs []string) {
GetGlobalResolver().ResolveAddrsParallel(addrs)
}

View File

@@ -0,0 +1,159 @@
package resolver
import (
"fmt"
"testing"
"time"
)
func BenchmarkResolveAddr_CacheHit(b *testing.B) {
r := New(100 * time.Millisecond)
addr := "127.0.0.1"
// pre-populate cache
r.ResolveAddr(addr)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.ResolveAddr(addr)
}
}
func BenchmarkResolveAddr_CacheMiss(b *testing.B) {
r := New(10 * time.Millisecond) // short timeout for faster benchmarks
b.ResetTimer()
for i := 0; i < b.N; i++ {
// use different addresses to avoid cache hits
addr := fmt.Sprintf("127.0.0.%d", i%256)
r.ClearCache() // clear cache to force miss
r.ResolveAddr(addr)
}
}
func BenchmarkResolveAddr_NoCache(b *testing.B) {
r := New(10 * time.Millisecond)
r.SetNoCache(true)
addr := "127.0.0.1"
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.ResolveAddr(addr)
}
}
func BenchmarkResolvePort_CacheHit(b *testing.B) {
r := New(100 * time.Millisecond)
// pre-populate cache
r.ResolvePort(80, "tcp")
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.ResolvePort(80, "tcp")
}
}
func BenchmarkResolvePort_WellKnown(b *testing.B) {
r := New(100 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.ClearCache()
r.ResolvePort(443, "tcp")
}
}
func BenchmarkGetServiceName(b *testing.B) {
for i := 0; i < b.N; i++ {
getServiceName(80, "tcp")
}
}
func BenchmarkGetServiceName_NotFound(b *testing.B) {
for i := 0; i < b.N; i++ {
getServiceName(12345, "tcp")
}
}
func BenchmarkResolveAddrsParallel_10(b *testing.B) {
benchmarkResolveAddrsParallel(b, 10)
}
func BenchmarkResolveAddrsParallel_100(b *testing.B) {
benchmarkResolveAddrsParallel(b, 100)
}
func BenchmarkResolveAddrsParallel_1000(b *testing.B) {
benchmarkResolveAddrsParallel(b, 1000)
}
func benchmarkResolveAddrsParallel(b *testing.B, count int) {
addrs := make([]string, count)
for i := 0; i < count; i++ {
addrs[i] = fmt.Sprintf("127.0.%d.%d", i/256, i%256)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
r := New(10 * time.Millisecond)
r.ResolveAddrsParallel(addrs)
}
}
func BenchmarkConcurrentResolveAddr(b *testing.B) {
r := New(100 * time.Millisecond)
addr := "127.0.0.1"
// pre-populate cache
r.ResolveAddr(addr)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
r.ResolveAddr(addr)
}
})
}
func BenchmarkConcurrentResolvePort(b *testing.B) {
r := New(100 * time.Millisecond)
// pre-populate cache
r.ResolvePort(80, "tcp")
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
r.ResolvePort(80, "tcp")
}
})
}
func BenchmarkGetCacheSize(b *testing.B) {
r := New(100 * time.Millisecond)
// populate with some entries
for i := 0; i < 100; i++ {
r.ResolvePort(i+1, "tcp")
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.GetCacheSize()
}
}
func BenchmarkClearCache(b *testing.B) {
r := New(100 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// populate and clear
for j := 0; j < 10; j++ {
r.ResolvePort(j+1, "tcp")
}
r.ClearCache()
}
}

View File

@@ -0,0 +1,387 @@
package resolver
import (
"sync"
"testing"
"time"
)
func TestNew(t *testing.T) {
r := New(100 * time.Millisecond)
if r == nil {
t.Fatal("expected non-nil resolver")
}
if r.timeout != 100*time.Millisecond {
t.Errorf("expected timeout 100ms, got %v", r.timeout)
}
if r.cache == nil {
t.Error("expected cache to be initialized")
}
if r.noCache {
t.Error("expected noCache to be false by default")
}
}
func TestSetNoCache(t *testing.T) {
r := New(100 * time.Millisecond)
r.SetNoCache(true)
if !r.noCache {
t.Error("expected noCache to be true")
}
r.SetNoCache(false)
if r.noCache {
t.Error("expected noCache to be false")
}
}
func TestResolveAddr_InvalidIP(t *testing.T) {
r := New(100 * time.Millisecond)
// invalid ip should return as-is
result := r.ResolveAddr("not-an-ip")
if result != "not-an-ip" {
t.Errorf("expected 'not-an-ip', got %q", result)
}
// empty string should return as-is
result = r.ResolveAddr("")
if result != "" {
t.Errorf("expected empty string, got %q", result)
}
}
func TestResolveAddr_Caching(t *testing.T) {
r := New(100 * time.Millisecond)
// first call should cache
addr := "127.0.0.1"
result1 := r.ResolveAddr(addr)
// verify cache is populated
if r.GetCacheSize() != 1 {
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
}
// second call should use cache
result2 := r.ResolveAddr(addr)
if result1 != result2 {
t.Errorf("expected same result from cache, got %q and %q", result1, result2)
}
}
func TestResolveAddr_NoCacheMode(t *testing.T) {
r := New(100 * time.Millisecond)
r.SetNoCache(true)
addr := "127.0.0.1"
r.ResolveAddr(addr)
// cache should remain empty when noCache is enabled
if r.GetCacheSize() != 0 {
t.Errorf("expected cache size 0 with noCache, got %d", r.GetCacheSize())
}
}
func TestResolvePort_Zero(t *testing.T) {
r := New(100 * time.Millisecond)
result := r.ResolvePort(0, "tcp")
if result != "0" {
t.Errorf("expected '0' for port 0, got %q", result)
}
}
func TestResolvePort_WellKnown(t *testing.T) {
r := New(100 * time.Millisecond)
tests := []struct {
port int
proto string
expected string
}{
{80, "tcp", "http"},
{443, "tcp", "https"},
{22, "tcp", "ssh"},
{53, "udp", "domain"},
{5432, "tcp", "postgresql"},
}
for _, tt := range tests {
result := r.ResolvePort(tt.port, tt.proto)
if result != tt.expected {
t.Errorf("ResolvePort(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
}
}
}
func TestResolvePort_Caching(t *testing.T) {
r := New(100 * time.Millisecond)
r.ResolvePort(80, "tcp")
r.ResolvePort(443, "tcp")
if r.GetCacheSize() != 2 {
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
}
// same port/proto should not add new entry
r.ResolvePort(80, "tcp")
if r.GetCacheSize() != 2 {
t.Errorf("expected cache size still 2, got %d", r.GetCacheSize())
}
}
func TestResolveAddrPort(t *testing.T) {
r := New(100 * time.Millisecond)
addr, port := r.ResolveAddrPort("127.0.0.1", 80, "tcp")
if addr == "" {
t.Error("expected non-empty address")
}
if port != "http" {
t.Errorf("expected port 'http', got %q", port)
}
}
func TestClearCache(t *testing.T) {
r := New(100 * time.Millisecond)
r.ResolveAddr("127.0.0.1")
r.ResolvePort(80, "tcp")
if r.GetCacheSize() == 0 {
t.Error("expected non-empty cache before clear")
}
r.ClearCache()
if r.GetCacheSize() != 0 {
t.Errorf("expected empty cache after clear, got %d", r.GetCacheSize())
}
}
func TestGetCacheSize(t *testing.T) {
r := New(100 * time.Millisecond)
if r.GetCacheSize() != 0 {
t.Errorf("expected initial cache size 0, got %d", r.GetCacheSize())
}
r.ResolveAddr("127.0.0.1")
if r.GetCacheSize() != 1 {
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
}
}
func TestGetServiceName(t *testing.T) {
tests := []struct {
port int
proto string
expected string
}{
{80, "tcp", "http"},
{443, "tcp", "https"},
{22, "tcp", "ssh"},
{53, "tcp", "domain"},
{53, "udp", "domain"},
{12345, "tcp", ""},
{0, "tcp", ""},
}
for _, tt := range tests {
result := getServiceName(tt.port, tt.proto)
if result != tt.expected {
t.Errorf("getServiceName(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
}
}
}
func TestResolveAddrsParallel(t *testing.T) {
r := New(100 * time.Millisecond)
addrs := []string{
"127.0.0.1",
"127.0.0.2",
"127.0.0.3",
"", // should be skipped
"*", // should be skipped
}
r.ResolveAddrsParallel(addrs)
// should have cached 3 addresses (excluding empty and *)
if r.GetCacheSize() != 3 {
t.Errorf("expected cache size 3, got %d", r.GetCacheSize())
}
}
func TestResolveAddrsParallel_Dedupe(t *testing.T) {
r := New(100 * time.Millisecond)
addrs := []string{
"127.0.0.1",
"127.0.0.1",
"127.0.0.1",
"127.0.0.2",
}
r.ResolveAddrsParallel(addrs)
// should have cached 2 unique addresses
if r.GetCacheSize() != 2 {
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
}
}
func TestResolveAddrsParallel_SkipsCached(t *testing.T) {
r := New(100 * time.Millisecond)
// pre-cache one address
r.ResolveAddr("127.0.0.1")
addrs := []string{
"127.0.0.1", // already cached
"127.0.0.2", // not cached
}
initialSize := r.GetCacheSize()
r.ResolveAddrsParallel(addrs)
// should have added 1 more
if r.GetCacheSize() != initialSize+1 {
t.Errorf("expected cache size %d, got %d", initialSize+1, r.GetCacheSize())
}
}
func TestResolveAddrsParallel_Empty(t *testing.T) {
r := New(100 * time.Millisecond)
// should not panic with empty input
r.ResolveAddrsParallel([]string{})
r.ResolveAddrsParallel(nil)
if r.GetCacheSize() != 0 {
t.Errorf("expected cache size 0, got %d", r.GetCacheSize())
}
}
func TestGlobalResolver(t *testing.T) {
// reset global resolver
globalResolver = nil
r := GetGlobalResolver()
if r == nil {
t.Fatal("expected non-nil global resolver")
}
// should return same instance
r2 := GetGlobalResolver()
if r != r2 {
t.Error("expected same global resolver instance")
}
}
func TestSetGlobalResolver(t *testing.T) {
SetGlobalResolver(ResolverOptions{
Timeout: 500 * time.Millisecond,
NoCache: true,
})
r := GetGlobalResolver()
if r.timeout != 500*time.Millisecond {
t.Errorf("expected timeout 500ms, got %v", r.timeout)
}
if !r.noCache {
t.Error("expected noCache to be true")
}
// reset for other tests
globalResolver = nil
}
func TestSetGlobalResolver_DefaultTimeout(t *testing.T) {
SetGlobalResolver(ResolverOptions{
Timeout: 0, // should use default
})
r := GetGlobalResolver()
if r.timeout != 200*time.Millisecond {
t.Errorf("expected default timeout 200ms, got %v", r.timeout)
}
// reset for other tests
globalResolver = nil
}
func TestGlobalConvenienceFunctions(t *testing.T) {
globalResolver = nil
// test global ResolveAddr
result := ResolveAddr("127.0.0.1")
if result == "" {
t.Error("expected non-empty result from global ResolveAddr")
}
// test global ResolvePort
port := ResolvePort(80, "tcp")
if port != "http" {
t.Errorf("expected 'http', got %q", port)
}
// test global ResolveAddrPort
addr, portStr := ResolveAddrPort("127.0.0.1", 443, "tcp")
if addr == "" {
t.Error("expected non-empty address")
}
if portStr != "https" {
t.Errorf("expected 'https', got %q", portStr)
}
// test global SetNoCache
SetNoCache(true)
if !GetGlobalResolver().noCache {
t.Error("expected global noCache to be true")
}
// reset
globalResolver = nil
}
func TestConcurrentAccess(t *testing.T) {
r := New(100 * time.Millisecond)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
addr := "127.0.0.1"
r.ResolveAddr(addr)
r.ResolvePort(80+n%10, "tcp")
r.GetCacheSize()
}(i)
}
wg.Wait()
// should not panic and cache should have entries
if r.GetCacheSize() == 0 {
t.Error("expected non-empty cache after concurrent access")
}
}
func TestResolveAddr_TrailingDot(t *testing.T) {
// this test verifies the trailing dot removal logic
// by checking the internal logic works correctly
r := New(100 * time.Millisecond)
// localhost should resolve and have trailing dot removed
result := r.ResolveAddr("127.0.0.1")
if len(result) > 0 && result[len(result)-1] == '.' {
t.Error("expected trailing dot to be removed")
}
}

133
internal/state/state.go Normal file
View File

@@ -0,0 +1,133 @@
package state
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"github.com/karol-broda/snitch/internal/collector"
)
// TUIState holds view options that can be persisted between sessions
type TUIState struct {
ShowTCP bool `json:"show_tcp"`
ShowUDP bool `json:"show_udp"`
ShowListening bool `json:"show_listening"`
ShowEstablished bool `json:"show_established"`
ShowOther bool `json:"show_other"`
SortField collector.SortField `json:"sort_field"`
SortReverse bool `json:"sort_reverse"`
ResolveAddrs bool `json:"resolve_addrs"`
ResolvePorts bool `json:"resolve_ports"`
}
var (
saveMu sync.Mutex
saveChan chan TUIState
once sync.Once
)
// Path returns the XDG-compliant state file path
func Path() string {
stateDir := os.Getenv("XDG_STATE_HOME")
if stateDir == "" {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
stateDir = filepath.Join(home, ".local", "state")
}
return filepath.Join(stateDir, "snitch", "tui.json")
}
// Load reads the TUI state from disk.
// returns nil if state file doesn't exist or can't be read.
func Load() *TUIState {
path := Path()
if path == "" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return nil
}
var state TUIState
if err := json.Unmarshal(data, &state); err != nil {
return nil
}
return &state
}
// Save writes the TUI state to disk synchronously.
// creates parent directories if needed.
func Save(state TUIState) error {
path := Path()
if path == "" {
return nil
}
saveMu.Lock()
defer saveMu.Unlock()
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
data, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0644)
}
// SaveAsync queues a state save to happen in the background.
// only the most recent state is saved if multiple saves are queued.
func SaveAsync(state TUIState) {
once.Do(func() {
saveChan = make(chan TUIState, 1)
go saveWorker()
})
// non-blocking send, replace pending save with newer state
select {
case saveChan <- state:
default:
// channel full, drain and replace
select {
case <-saveChan:
default:
}
select {
case saveChan <- state:
default:
}
}
}
func saveWorker() {
for state := range saveChan {
_ = Save(state)
}
}
// Default returns a TUIState with default values
func Default() TUIState {
return TUIState{
ShowTCP: true,
ShowUDP: true,
ShowListening: true,
ShowEstablished: true,
ShowOther: true,
SortField: collector.SortByLport,
SortReverse: false,
ResolveAddrs: false,
ResolvePorts: false,
}
}

View File

@@ -0,0 +1,236 @@
package state
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/karol-broda/snitch/internal/collector"
)
func TestPath_XDGStateHome(t *testing.T) {
t.Setenv("XDG_STATE_HOME", "/custom/state")
path := Path()
expected := "/custom/state/snitch/tui.json"
if path != expected {
t.Errorf("Path() = %q, want %q", path, expected)
}
}
func TestPath_DefaultFallback(t *testing.T) {
t.Setenv("XDG_STATE_HOME", "")
path := Path()
home, err := os.UserHomeDir()
if err != nil {
t.Skip("cannot determine home directory")
}
expected := filepath.Join(home, ".local", "state", "snitch", "tui.json")
if path != expected {
t.Errorf("Path() = %q, want %q", path, expected)
}
}
func TestDefault(t *testing.T) {
d := Default()
if d.ShowTCP != true {
t.Error("expected ShowTCP to be true")
}
if d.ShowUDP != true {
t.Error("expected ShowUDP to be true")
}
if d.ShowListening != true {
t.Error("expected ShowListening to be true")
}
if d.ShowEstablished != true {
t.Error("expected ShowEstablished to be true")
}
if d.ShowOther != true {
t.Error("expected ShowOther to be true")
}
if d.SortField != collector.SortByLport {
t.Errorf("expected SortField to be %q, got %q", collector.SortByLport, d.SortField)
}
if d.SortReverse != false {
t.Error("expected SortReverse to be false")
}
if d.ResolveAddrs != false {
t.Error("expected ResolveAddrs to be false")
}
if d.ResolvePorts != false {
t.Error("expected ResolvePorts to be false")
}
}
func TestSaveAndLoad(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
state := TUIState{
ShowTCP: false,
ShowUDP: true,
ShowListening: true,
ShowEstablished: false,
ShowOther: true,
SortField: collector.SortByProcess,
SortReverse: true,
ResolveAddrs: true,
ResolvePorts: false,
}
err := Save(state)
if err != nil {
t.Fatalf("Save() error = %v", err)
}
// verify file was created
path := Path()
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Fatal("expected state file to exist after Save()")
}
loaded := Load()
if loaded == nil {
t.Fatal("Load() returned nil")
}
if loaded.ShowTCP != state.ShowTCP {
t.Errorf("ShowTCP = %v, want %v", loaded.ShowTCP, state.ShowTCP)
}
if loaded.ShowUDP != state.ShowUDP {
t.Errorf("ShowUDP = %v, want %v", loaded.ShowUDP, state.ShowUDP)
}
if loaded.ShowListening != state.ShowListening {
t.Errorf("ShowListening = %v, want %v", loaded.ShowListening, state.ShowListening)
}
if loaded.ShowEstablished != state.ShowEstablished {
t.Errorf("ShowEstablished = %v, want %v", loaded.ShowEstablished, state.ShowEstablished)
}
if loaded.ShowOther != state.ShowOther {
t.Errorf("ShowOther = %v, want %v", loaded.ShowOther, state.ShowOther)
}
if loaded.SortField != state.SortField {
t.Errorf("SortField = %v, want %v", loaded.SortField, state.SortField)
}
if loaded.SortReverse != state.SortReverse {
t.Errorf("SortReverse = %v, want %v", loaded.SortReverse, state.SortReverse)
}
if loaded.ResolveAddrs != state.ResolveAddrs {
t.Errorf("ResolveAddrs = %v, want %v", loaded.ResolveAddrs, state.ResolveAddrs)
}
if loaded.ResolvePorts != state.ResolvePorts {
t.Errorf("ResolvePorts = %v, want %v", loaded.ResolvePorts, state.ResolvePorts)
}
}
func TestLoad_NonExistent(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
loaded := Load()
if loaded != nil {
t.Error("expected Load() to return nil for non-existent file")
}
}
func TestLoad_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
// create directory and invalid json file
stateDir := filepath.Join(tmpDir, "snitch")
if err := os.MkdirAll(stateDir, 0755); err != nil {
t.Fatal(err)
}
stateFile := filepath.Join(stateDir, "tui.json")
if err := os.WriteFile(stateFile, []byte("not valid json"), 0644); err != nil {
t.Fatal(err)
}
loaded := Load()
if loaded != nil {
t.Error("expected Load() to return nil for invalid JSON")
}
}
func TestSave_CreatesDirectories(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
// snitch directory should not exist yet
snitchDir := filepath.Join(tmpDir, "snitch")
if _, err := os.Stat(snitchDir); err == nil {
t.Fatal("expected snitch directory to not exist initially")
}
err := Save(Default())
if err != nil {
t.Fatalf("Save() error = %v", err)
}
// directory should now exist
if _, err := os.Stat(snitchDir); os.IsNotExist(err) {
t.Error("expected Save() to create parent directories")
}
}
func TestSaveAsync(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
state := TUIState{
ShowTCP: false,
SortField: collector.SortByPID,
}
SaveAsync(state)
// wait for background save with timeout
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if loaded := Load(); loaded != nil {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Log("SaveAsync may not have completed in time (non-fatal in CI)")
}
func TestTUIState_JSONRoundtrip(t *testing.T) {
// verify all sort fields serialize correctly
sortFields := []collector.SortField{
collector.SortByLport,
collector.SortByProcess,
collector.SortByPID,
collector.SortByState,
collector.SortByProto,
}
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
for _, sf := range sortFields {
state := TUIState{
ShowTCP: true,
SortField: sf,
}
if err := Save(state); err != nil {
t.Fatalf("Save() error for %q: %v", sf, err)
}
loaded := Load()
if loaded == nil {
t.Fatalf("Load() returned nil for %q", sf)
}
if loaded.SortField != sf {
t.Errorf("SortField roundtrip failed: got %q, want %q", loaded.SortField, sf)
}
}
}

View File

@@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/karol-broda/snitch/internal/collector" "github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/errutil"
) )
// TestCollector wraps MockCollector for use in tests // TestCollector wraps MockCollector for use in tests
@@ -47,13 +48,13 @@ func SetupTestEnvironment(t *testing.T) (string, func()) {
oldConfig := os.Getenv("SNITCH_CONFIG") oldConfig := os.Getenv("SNITCH_CONFIG")
oldNoColor := os.Getenv("SNITCH_NO_COLOR") oldNoColor := os.Getenv("SNITCH_NO_COLOR")
os.Setenv("SNITCH_NO_COLOR", "1") // Disable colors in tests errutil.Setenv("SNITCH_NO_COLOR", "1")
// Cleanup function // Cleanup function
cleanup := func() { cleanup := func() {
os.RemoveAll(tempDir) errutil.RemoveAll(tempDir)
os.Setenv("SNITCH_CONFIG", oldConfig) errutil.Setenv("SNITCH_CONFIG", oldConfig)
os.Setenv("SNITCH_NO_COLOR", oldNoColor) errutil.Setenv("SNITCH_NO_COLOR", oldNoColor)
} }
return tempDir, cleanup return tempDir, cleanup
@@ -192,8 +193,8 @@ func (oc *OutputCapture) Stop() (string, string, error) {
os.Stderr = oc.oldStderr os.Stderr = oc.oldStderr
// Close files // Close files
oc.stdout.Close() errutil.Close(oc.stdout)
oc.stderr.Close() errutil.Close(oc.stderr)
// Read captured content // Read captured content
stdoutContent, err := os.ReadFile(oc.stdoutFile) stdoutContent, err := os.ReadFile(oc.stdoutFile)
@@ -207,9 +208,9 @@ func (oc *OutputCapture) Stop() (string, string, error) {
} }
// Cleanup // Cleanup
os.Remove(oc.stdoutFile) errutil.Remove(oc.stdoutFile)
os.Remove(oc.stderrFile) errutil.Remove(oc.stderrFile)
os.Remove(filepath.Dir(oc.stdoutFile)) errutil.Remove(filepath.Dir(oc.stdoutFile))
return string(stdoutContent), string(stderrContent), nil return string(stdoutContent), string(stderrContent), nil
} }

24
internal/theme/ansi.go Normal file
View File

@@ -0,0 +1,24 @@
package theme
// ANSI palette uses standard terminal colors (0-15)
// this allows the theme to inherit from the user's terminal color scheme
var paletteANSI = Palette{
Name: "ansi",
Fg: "15", // bright white
FgMuted: "7", // white
FgSubtle: "8", // bright black (gray)
Bg: "0", // black
BgMuted: "0", // black
Border: "8", // bright black (gray)
Red: "1", // red
Green: "2", // green
Yellow: "3", // yellow
Blue: "4", // blue
Magenta: "5", // magenta
Cyan: "6", // cyan
Orange: "3", // yellow (ansi has no orange, fallback to yellow)
Gray: "8", // bright black
}

View File

@@ -0,0 +1,87 @@
package theme
// catppuccin mocha (dark)
// https://github.com/catppuccin/catppuccin
var paletteCatppuccinMocha = Palette{
Name: "catppuccin-mocha",
Fg: "#cdd6f4", // text
FgMuted: "#a6adc8", // subtext0
FgSubtle: "#6c7086", // overlay0
Bg: "#1e1e2e", // base
BgMuted: "#313244", // surface0
Border: "#45475a", // surface1
Red: "#f38ba8",
Green: "#a6e3a1",
Yellow: "#f9e2af",
Blue: "#89b4fa",
Magenta: "#cba6f7", // mauve
Cyan: "#94e2d5", // teal
Orange: "#fab387", // peach
Gray: "#585b70", // surface2
}
// catppuccin macchiato (medium-dark)
var paletteCatppuccinMacchiato = Palette{
Name: "catppuccin-macchiato",
Fg: "#cad3f5", // text
FgMuted: "#a5adcb", // subtext0
FgSubtle: "#6e738d", // overlay0
Bg: "#24273a", // base
BgMuted: "#363a4f", // surface0
Border: "#494d64", // surface1
Red: "#ed8796",
Green: "#a6da95",
Yellow: "#eed49f",
Blue: "#8aadf4",
Magenta: "#c6a0f6", // mauve
Cyan: "#8bd5ca", // teal
Orange: "#f5a97f", // peach
Gray: "#5b6078", // surface2
}
// catppuccin frappe (medium)
var paletteCatppuccinFrappe = Palette{
Name: "catppuccin-frappe",
Fg: "#c6d0f5", // text
FgMuted: "#a5adce", // subtext0
FgSubtle: "#737994", // overlay0
Bg: "#303446", // base
BgMuted: "#414559", // surface0
Border: "#51576d", // surface1
Red: "#e78284",
Green: "#a6d189",
Yellow: "#e5c890",
Blue: "#8caaee",
Magenta: "#ca9ee6", // mauve
Cyan: "#81c8be", // teal
Orange: "#ef9f76", // peach
Gray: "#626880", // surface2
}
// catppuccin latte (light)
var paletteCatppuccinLatte = Palette{
Name: "catppuccin-latte",
Fg: "#4c4f69", // text
FgMuted: "#6c6f85", // subtext0
FgSubtle: "#9ca0b0", // overlay0
Bg: "#eff1f5", // base
BgMuted: "#ccd0da", // surface0
Border: "#bcc0cc", // surface1
Red: "#d20f39",
Green: "#40a02b",
Yellow: "#df8e1d",
Blue: "#1e66f5",
Magenta: "#8839ef", // mauve
Cyan: "#179299", // teal
Orange: "#fe640b", // peach
Gray: "#acb0be", // surface2
}

24
internal/theme/dracula.go Normal file
View File

@@ -0,0 +1,24 @@
package theme
// dracula theme
// https://draculatheme.com/
var paletteDracula = Palette{
Name: "dracula",
Fg: "#f8f8f2", // foreground
FgMuted: "#f8f8f2", // foreground
FgSubtle: "#6272a4", // comment
Bg: "#282a36", // background
BgMuted: "#44475a", // selection
Border: "#44475a", // selection
Red: "#ff5555",
Green: "#50fa7b",
Yellow: "#f1fa8c",
Blue: "#6272a4", // dracula uses comment color for blue tones
Magenta: "#bd93f9", // purple
Cyan: "#8be9fd",
Orange: "#ffb86c",
Gray: "#6272a4", // comment
}

45
internal/theme/gruvbox.go Normal file
View File

@@ -0,0 +1,45 @@
package theme
// gruvbox dark
// https://github.com/morhetz/gruvbox
var paletteGruvboxDark = Palette{
Name: "gruvbox-dark",
Fg: "#ebdbb2", // fg
FgMuted: "#d5c4a1", // fg2
FgSubtle: "#a89984", // fg4
Bg: "#282828", // bg
BgMuted: "#3c3836", // bg1
Border: "#504945", // bg2
Red: "#fb4934",
Green: "#b8bb26",
Yellow: "#fabd2f",
Blue: "#83a598",
Magenta: "#d3869b", // purple
Cyan: "#8ec07c", // aqua
Orange: "#fe8019",
Gray: "#928374",
}
// gruvbox light
var paletteGruvboxLight = Palette{
Name: "gruvbox-light",
Fg: "#3c3836", // fg
FgMuted: "#504945", // fg2
FgSubtle: "#7c6f64", // fg4
Bg: "#fbf1c7", // bg
BgMuted: "#ebdbb2", // bg1
Border: "#d5c4a1", // bg2
Red: "#cc241d",
Green: "#98971a",
Yellow: "#d79921",
Blue: "#458588",
Magenta: "#b16286", // purple
Cyan: "#689d6a", // aqua
Orange: "#d65d0e",
Gray: "#928374",
}

49
internal/theme/mono.go Normal file
View File

@@ -0,0 +1,49 @@
package theme
import "github.com/charmbracelet/lipgloss"
// createMonoTheme creates a monochrome theme (no colors)
// useful for accessibility, piping output, or minimal terminals
func createMonoTheme() *Theme {
baseStyle := lipgloss.NewStyle()
boldStyle := lipgloss.NewStyle().Bold(true)
return &Theme{
Name: "mono",
Styles: Styles{
Header: boldStyle,
Border: baseStyle,
Selected: boldStyle,
Watched: boldStyle,
Normal: baseStyle,
Error: boldStyle,
Success: boldStyle,
Warning: boldStyle,
Footer: baseStyle,
Background: baseStyle,
Proto: ProtoStyles{
TCP: baseStyle,
UDP: baseStyle,
Unix: baseStyle,
TCP6: baseStyle,
UDP6: baseStyle,
},
State: StateStyles{
Listen: baseStyle,
Established: baseStyle,
TimeWait: baseStyle,
CloseWait: baseStyle,
SynSent: baseStyle,
SynRecv: baseStyle,
FinWait1: baseStyle,
FinWait2: baseStyle,
Closing: baseStyle,
LastAck: baseStyle,
Closed: baseStyle,
},
},
}
}

24
internal/theme/nord.go Normal file
View File

@@ -0,0 +1,24 @@
package theme
// nord theme
// https://www.nordtheme.com/
var paletteNord = Palette{
Name: "nord",
Fg: "#eceff4", // snow storm - nord6
FgMuted: "#d8dee9", // snow storm - nord4
FgSubtle: "#4c566a", // polar night - nord3
Bg: "#2e3440", // polar night - nord0
BgMuted: "#3b4252", // polar night - nord1
Border: "#434c5e", // polar night - nord2
Red: "#bf616a", // aurora - nord11
Green: "#a3be8c", // aurora - nord14
Yellow: "#ebcb8b", // aurora - nord13
Blue: "#81a1c1", // frost - nord9
Magenta: "#b48ead", // aurora - nord15
Cyan: "#88c0d0", // frost - nord8
Orange: "#d08770", // aurora - nord12
Gray: "#4c566a", // polar night - nord3
}

View File

@@ -0,0 +1,24 @@
package theme
// one dark theme (atom editor)
// https://github.com/atom/atom/tree/master/packages/one-dark-syntax
var paletteOneDark = Palette{
Name: "one-dark",
Fg: "#abb2bf", // foreground
FgMuted: "#9da5b4", // foreground muted
FgSubtle: "#5c6370", // comment
Bg: "#282c34", // background
BgMuted: "#21252b", // gutter background
Border: "#3e4451", // selection
Red: "#e06c75",
Green: "#98c379",
Yellow: "#e5c07b",
Blue: "#61afef",
Magenta: "#c678dd", // purple
Cyan: "#56b6c2",
Orange: "#d19a66",
Gray: "#5c6370", // comment
}

111
internal/theme/palette.go Normal file
View File

@@ -0,0 +1,111 @@
package theme
import (
"strconv"
"github.com/charmbracelet/lipgloss"
)
// Palette defines the semantic colors for a theme
type Palette struct {
Name string
// base colors
Fg string // primary foreground
FgMuted string // secondary/muted foreground
FgSubtle string // subtle/disabled foreground
Bg string // primary background
BgMuted string // secondary background (selections, highlights)
Border string // border color
// semantic colors
Red string
Green string
Yellow string
Blue string
Magenta string
Cyan string
Orange string
Gray string
}
// Color converts a palette color string to a lipgloss.TerminalColor.
// If the string is 1-2 characters, it's treated as an ANSI color code.
// Otherwise, it's treated as a hex color.
func (p *Palette) Color(c string) lipgloss.TerminalColor {
if c == "" {
return lipgloss.NoColor{}
}
if len(c) <= 2 {
n, err := strconv.Atoi(c)
if err == nil {
return lipgloss.ANSIColor(n)
}
}
return lipgloss.Color(c)
}
// ToTheme converts a Palette to a Theme with lipgloss styles
func (p *Palette) ToTheme() *Theme {
return &Theme{
Name: p.Name,
Styles: Styles{
Header: lipgloss.NewStyle().
Bold(true).
Foreground(p.Color(p.Fg)),
Border: lipgloss.NewStyle().
Foreground(p.Color(p.Border)),
Selected: lipgloss.NewStyle().
Bold(true).
Foreground(p.Color(p.Fg)),
Watched: lipgloss.NewStyle().
Bold(true).
Foreground(p.Color(p.Orange)),
Normal: lipgloss.NewStyle().
Foreground(p.Color(p.FgMuted)),
Error: lipgloss.NewStyle().
Foreground(p.Color(p.Red)),
Success: lipgloss.NewStyle().
Foreground(p.Color(p.Green)),
Warning: lipgloss.NewStyle().
Foreground(p.Color(p.Yellow)),
Footer: lipgloss.NewStyle().
Foreground(p.Color(p.FgSubtle)),
Background: lipgloss.NewStyle(),
Proto: ProtoStyles{
TCP: lipgloss.NewStyle().Foreground(p.Color(p.Green)),
UDP: lipgloss.NewStyle().Foreground(p.Color(p.Magenta)),
Unix: lipgloss.NewStyle().Foreground(p.Color(p.Gray)),
TCP6: lipgloss.NewStyle().Foreground(p.Color(p.Cyan)),
UDP6: lipgloss.NewStyle().Foreground(p.Color(p.Blue)),
},
State: StateStyles{
Listen: lipgloss.NewStyle().Foreground(p.Color(p.Green)),
Established: lipgloss.NewStyle().Foreground(p.Color(p.Blue)),
TimeWait: lipgloss.NewStyle().Foreground(p.Color(p.Yellow)),
CloseWait: lipgloss.NewStyle().Foreground(p.Color(p.Orange)),
SynSent: lipgloss.NewStyle().Foreground(p.Color(p.Magenta)),
SynRecv: lipgloss.NewStyle().Foreground(p.Color(p.Magenta)),
FinWait1: lipgloss.NewStyle().Foreground(p.Color(p.Red)),
FinWait2: lipgloss.NewStyle().Foreground(p.Color(p.Red)),
Closing: lipgloss.NewStyle().Foreground(p.Color(p.Red)),
LastAck: lipgloss.NewStyle().Foreground(p.Color(p.Red)),
Closed: lipgloss.NewStyle().Foreground(p.Color(p.Gray)),
},
},
}
}

14
internal/theme/readme.md Normal file
View File

@@ -0,0 +1,14 @@
# theme Palettes
the color palettes in this directory were generated by an LLM agent (Claude Opus 4.5) using web search to fetch the official color specifications from each themes documentation
as it is with llm agents its possible the colors may be wrong
Sources:
- [Catppuccin](https://github.com/catppuccin/catppuccin)
- [Dracula](https://draculatheme.com/)
- [Gruvbox](https://github.com/morhetz/gruvbox)
- [Nord](https://www.nordtheme.com/)
- [One Dark](https://github.com/atom/one-dark-syntax)
- [Solarized](https://ethanschoonover.com/solarized/)
- [Tokyo Night](https://github.com/enkia/tokyo-night-vscode-theme)

View File

@@ -0,0 +1,45 @@
package theme
// solarized dark theme
// https://ethanschoonover.com/solarized/
var paletteSolarizedDark = Palette{
Name: "solarized-dark",
Fg: "#839496", // base0
FgMuted: "#93a1a1", // base1
FgSubtle: "#586e75", // base01
Bg: "#002b36", // base03
BgMuted: "#073642", // base02
Border: "#073642", // base02
Red: "#dc322f",
Green: "#859900",
Yellow: "#b58900",
Blue: "#268bd2",
Magenta: "#d33682",
Cyan: "#2aa198",
Orange: "#cb4b16",
Gray: "#657b83", // base00
}
// solarized light theme
var paletteSolarizedLight = Palette{
Name: "solarized-light",
Fg: "#657b83", // base00
FgMuted: "#586e75", // base01
FgSubtle: "#93a1a1", // base1
Bg: "#fdf6e3", // base3
BgMuted: "#eee8d5", // base2
Border: "#eee8d5", // base2
Red: "#dc322f",
Green: "#859900",
Yellow: "#b58900",
Blue: "#268bd2",
Magenta: "#d33682",
Cyan: "#2aa198",
Orange: "#cb4b16",
Gray: "#839496", // base0
}

View File

@@ -1,6 +1,7 @@
package theme package theme
import ( import (
"sort"
"strings" "strings"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
@@ -52,152 +53,73 @@ type StateStyles struct {
Closed lipgloss.Style Closed lipgloss.Style
} }
var ( var themes map[string]*Theme
themes map[string]*Theme
)
func init() { func init() {
themes = map[string]*Theme{ themes = make(map[string]*Theme)
"default": createAdaptiveTheme(),
"mono": createMonoTheme(), // ansi theme (default) - inherits from terminal colors
} themes["ansi"] = paletteANSI.ToTheme()
// catppuccin variants
themes["catppuccin-mocha"] = paletteCatppuccinMocha.ToTheme()
themes["catppuccin-macchiato"] = paletteCatppuccinMacchiato.ToTheme()
themes["catppuccin-frappe"] = paletteCatppuccinFrappe.ToTheme()
themes["catppuccin-latte"] = paletteCatppuccinLatte.ToTheme()
// gruvbox variants
themes["gruvbox-dark"] = paletteGruvboxDark.ToTheme()
themes["gruvbox-light"] = paletteGruvboxLight.ToTheme()
// dracula
themes["dracula"] = paletteDracula.ToTheme()
// nord
themes["nord"] = paletteNord.ToTheme()
// tokyo night variants
themes["tokyo-night"] = paletteTokyoNight.ToTheme()
themes["tokyo-night-storm"] = paletteTokyoNightStorm.ToTheme()
themes["tokyo-night-light"] = paletteTokyoNightLight.ToTheme()
// solarized variants
themes["solarized-dark"] = paletteSolarizedDark.ToTheme()
themes["solarized-light"] = paletteSolarizedLight.ToTheme()
// one dark
themes["one-dark"] = paletteOneDark.ToTheme()
// monochrome (no colors)
themes["mono"] = createMonoTheme()
} }
// GetTheme returns a theme by name, with auto-detection support // DefaultTheme is the theme used when none is specified
const DefaultTheme = "ansi"
// GetTheme returns a theme by name
func GetTheme(name string) *Theme { func GetTheme(name string) *Theme {
if name == "auto" { if name == "" || name == "auto" || name == "default" {
// lipgloss handles adaptive colors, so we just return the default return themes[DefaultTheme]
return themes["default"]
} }
if theme, exists := themes[name]; exists { if theme, exists := themes[name]; exists {
return theme return theme
} }
// a specific theme was requested (e.g. "dark", "light"), but we now use adaptive
// so we can just return the default theme and lipgloss will handle it
if name == "dark" || name == "light" {
return themes["default"]
}
// fallback to default // fallback to default
return themes["default"] return themes[DefaultTheme]
} }
// ListThemes returns available theme names // ListThemes returns available theme names sorted alphabetically
func ListThemes() []string { func ListThemes() []string {
var names []string names := make([]string, 0, len(themes))
for name := range themes { for name := range themes {
names = append(names, name) names = append(names, name)
} }
sort.Strings(names)
return names return names
} }
// createAdaptiveTheme creates a clean, minimal theme
func createAdaptiveTheme() *Theme {
return &Theme{
Name: "default",
Styles: Styles{
Header: lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "#1F2937", Dark: "#F9FAFB"}),
Watched: lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#F59E0B"}),
Border: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#D1D5DB", Dark: "#374151"}),
Selected: lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "#1F2937", Dark: "#F9FAFB"}),
Normal: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#6B7280", Dark: "#9CA3AF"}),
Error: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
Success: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
Warning: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
Footer: lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#9CA3AF", Dark: "#6B7280"}),
Background: lipgloss.NewStyle(),
Proto: ProtoStyles{
TCP: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
UDP: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
Unix: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#6B7280", Dark: "#9CA3AF"}),
TCP6: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
UDP6: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
},
State: StateStyles{
Listen: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#059669", Dark: "#34D399"}),
Established: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#2563EB", Dark: "#60A5FA"}),
TimeWait: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
CloseWait: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#D97706", Dark: "#FBBF24"}),
SynSent: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
SynRecv: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#7C3AED", Dark: "#A78BFA"}),
FinWait1: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
FinWait2: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
Closing: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
LastAck: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#DC2626", Dark: "#F87171"}),
Closed: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#9CA3AF", Dark: "#6B7280"}),
},
},
}
}
// createMonoTheme creates a monochrome theme (no colors)
func createMonoTheme() *Theme {
baseStyle := lipgloss.NewStyle()
boldStyle := lipgloss.NewStyle().Bold(true)
return &Theme{
Name: "mono",
Styles: Styles{
Header: boldStyle,
Border: baseStyle,
Selected: boldStyle,
Normal: baseStyle,
Error: boldStyle,
Success: boldStyle,
Warning: boldStyle,
Footer: baseStyle,
Background: baseStyle,
Proto: ProtoStyles{
TCP: baseStyle,
UDP: baseStyle,
Unix: baseStyle,
TCP6: baseStyle,
UDP6: baseStyle,
},
State: StateStyles{
Listen: baseStyle,
Established: baseStyle,
TimeWait: baseStyle,
CloseWait: baseStyle,
SynSent: baseStyle,
SynRecv: baseStyle,
FinWait1: baseStyle,
FinWait2: baseStyle,
Closing: baseStyle,
LastAck: baseStyle,
Closed: baseStyle,
},
},
}
}
// GetProtoStyle returns the appropriate style for a protocol // GetProtoStyle returns the appropriate style for a protocol
func (s *Styles) GetProtoStyle(proto string) lipgloss.Style { func (s *Styles) GetProtoStyle(proto string) lipgloss.Style {
switch strings.ToLower(proto) { switch strings.ToLower(proto) {

View File

@@ -0,0 +1,66 @@
package theme
// tokyo night theme
// https://github.com/enkia/tokyo-night-vscode-theme
var paletteTokyoNight = Palette{
Name: "tokyo-night",
Fg: "#c0caf5", // foreground
FgMuted: "#a9b1d6", // foreground dark
FgSubtle: "#565f89", // comment
Bg: "#1a1b26", // background
BgMuted: "#24283b", // background highlight
Border: "#414868", // border
Red: "#f7768e",
Green: "#9ece6a",
Yellow: "#e0af68",
Blue: "#7aa2f7",
Magenta: "#bb9af7", // purple
Cyan: "#7dcfff",
Orange: "#ff9e64",
Gray: "#565f89", // comment
}
// tokyo night storm variant
var paletteTokyoNightStorm = Palette{
Name: "tokyo-night-storm",
Fg: "#c0caf5", // foreground
FgMuted: "#a9b1d6", // foreground dark
FgSubtle: "#565f89", // comment
Bg: "#24283b", // background (storm is slightly lighter)
BgMuted: "#1f2335", // background dark
Border: "#414868", // border
Red: "#f7768e",
Green: "#9ece6a",
Yellow: "#e0af68",
Blue: "#7aa2f7",
Magenta: "#bb9af7", // purple
Cyan: "#7dcfff",
Orange: "#ff9e64",
Gray: "#565f89", // comment
}
// tokyo night light variant
var paletteTokyoNightLight = Palette{
Name: "tokyo-night-light",
Fg: "#343b58", // foreground
FgMuted: "#565a6e", // foreground dark
FgSubtle: "#9699a3", // comment
Bg: "#d5d6db", // background
BgMuted: "#cbccd1", // background highlight
Border: "#b4b5b9", // border
Red: "#8c4351",
Green: "#485e30",
Yellow: "#8f5e15",
Blue: "#34548a",
Magenta: "#5a4a78", // purple
Cyan: "#0f4b6e",
Orange: "#965027",
Gray: "#9699a3", // comment
}

View File

@@ -38,6 +38,10 @@ func sortFieldLabel(f collector.SortField) string {
return "state" return "state"
case collector.SortByProto: case collector.SortByProto:
return "proto" return "proto"
case collector.SortByRaddr:
return "raddr"
case collector.SortByRport:
return "rport"
default: default:
return "port" return "port"
} }

View File

@@ -2,10 +2,12 @@ package tui
import ( import (
"fmt" "fmt"
"github.com/karol-broda/snitch/internal/collector" "strings"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/karol-broda/snitch/internal/collector"
) )
func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
@@ -14,6 +16,11 @@ func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m.handleSearchKey(msg) return m.handleSearchKey(msg)
} }
// export modal captures all input
if m.showExportModal {
return m.handleExportKey(msg)
}
// kill confirmation dialog // kill confirmation dialog
if m.showKillConfirm { if m.showKillConfirm {
return m.handleKillConfirmKey(msg) return m.handleKillConfirmKey(msg)
@@ -52,6 +59,82 @@ func (m model) handleSearchKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m, nil return m, nil
} }
func (m model) handleExportKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "esc":
m.showExportModal = false
m.exportFilename = ""
m.exportFormat = ""
m.exportError = ""
case "tab":
// toggle format
if m.exportFormat == "tsv" {
m.exportFormat = "csv"
} else {
m.exportFormat = "tsv"
}
m.exportError = ""
case "enter":
// build final filename with extension
filename := m.exportFilename
if filename == "" {
filename = "connections"
}
ext := ".csv"
if m.exportFormat == "tsv" {
ext = ".tsv"
}
// only add extension if not already present
if !strings.HasSuffix(strings.ToLower(filename), ".csv") &&
!strings.HasSuffix(strings.ToLower(filename), ".tsv") {
filename = filename + ext
}
m.exportFilename = filename
err := m.exportConnections()
if err != nil {
m.exportError = err.Error()
return m, nil
}
visible := m.visibleConnections()
m.statusMessage = fmt.Sprintf("%s exported %d connections to %s", SymbolSuccess, len(visible), filename)
m.statusExpiry = time.Now().Add(3 * time.Second)
m.showExportModal = false
m.exportFilename = ""
m.exportFormat = ""
m.exportError = ""
return m, clearStatusAfter(3 * time.Second)
case "backspace":
if len(m.exportFilename) > 0 {
m.exportFilename = m.exportFilename[:len(m.exportFilename)-1]
}
m.exportError = ""
default:
// only accept valid filename characters
char := msg.String()
if len(char) == 1 && isValidFilenameChar(char[0]) {
m.exportFilename += char
m.exportError = ""
}
}
return m, nil
}
func isValidFilenameChar(c byte) bool {
// allow alphanumeric, dash, underscore, dot
return (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '_' || c == '.'
}
func (m model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() { switch msg.String() {
case "esc", "enter", "q": case "esc", "enter", "q":
@@ -118,37 +201,52 @@ func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
case "t": case "t":
m.showTCP = !m.showTCP m.showTCP = !m.showTCP
m.clampCursor() m.clampCursor()
m.saveState()
case "u": case "u":
m.showUDP = !m.showUDP m.showUDP = !m.showUDP
m.clampCursor() m.clampCursor()
m.saveState()
case "l": case "l":
m.showListening = !m.showListening m.showListening = !m.showListening
m.clampCursor() m.clampCursor()
m.saveState()
case "e": case "e":
m.showEstablished = !m.showEstablished m.showEstablished = !m.showEstablished
m.clampCursor() m.clampCursor()
m.saveState()
case "o": case "o":
m.showOther = !m.showOther m.showOther = !m.showOther
m.clampCursor() m.clampCursor()
m.saveState()
case "a": case "a":
m.showTCP = true m.showTCP = true
m.showUDP = true m.showUDP = true
m.showListening = true m.showListening = true
m.showEstablished = true m.showEstablished = true
m.showOther = true m.showOther = true
m.saveState()
// sorting // sorting
case "s": case "s":
m.cycleSort() m.cycleSort()
m.saveState()
case "S": case "S":
m.sortReverse = !m.sortReverse m.sortReverse = !m.sortReverse
m.applySorting() m.applySorting()
m.saveState()
// search // search
case "/": case "/":
m.searchActive = true m.searchActive = true
m.searchQuery = "" m.searchQuery = ""
// export
case "x":
m.showExportModal = true
m.exportFilename = ""
m.exportFormat = "csv"
m.exportError = ""
// actions // actions
case "enter", " ": case "enter", " ":
visible := m.visibleConnections() visible := m.visibleConnections()
@@ -220,6 +318,7 @@ func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.statusMessage = "address resolution: off" m.statusMessage = "address resolution: off"
} }
m.statusExpiry = time.Now().Add(2 * time.Second) m.statusExpiry = time.Now().Add(2 * time.Second)
m.saveState()
return m, clearStatusAfter(2 * time.Second) return m, clearStatusAfter(2 * time.Second)
// toggle port resolution // toggle port resolution
@@ -231,6 +330,7 @@ func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.statusMessage = "port resolution: off" m.statusMessage = "port resolution: off"
} }
m.statusExpiry = time.Now().Add(2 * time.Second) m.statusExpiry = time.Now().Add(2 * time.Second)
m.saveState()
return m, clearStatusAfter(2 * time.Second) return m, clearStatusAfter(2 * time.Second)
} }
@@ -266,6 +366,8 @@ func (m *model) cycleSort() {
collector.SortByPID, collector.SortByPID,
collector.SortByState, collector.SortByState,
collector.SortByProto, collector.SortByProto,
collector.SortByRaddr,
collector.SortByRport,
} }
for i, f := range fields { for i, f := range fields {

View File

@@ -3,6 +3,7 @@ package tui
import ( import (
"fmt" "fmt"
"github.com/karol-broda/snitch/internal/collector" "github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/resolver"
"syscall" "syscall"
"time" "time"
@@ -35,11 +36,20 @@ func (m model) tick() tea.Cmd {
} }
func (m model) fetchData() tea.Cmd { func (m model) fetchData() tea.Cmd {
resolveAddrs := m.resolveAddrs
return func() tea.Msg { return func() tea.Msg {
conns, err := collector.GetConnections() conns, err := collector.GetConnections()
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
// pre-warm dns cache in parallel if resolution is enabled
if resolveAddrs {
addrs := make([]string, 0, len(conns)*2)
for _, c := range conns {
addrs = append(addrs, c.Laddr, c.Raddr)
}
resolver.ResolveAddrsParallel(addrs)
}
return dataMsg{connections: conns} return dataMsg{connections: conns}
} }
} }

View File

@@ -2,11 +2,16 @@ package tui
import ( import (
"fmt" "fmt"
"github.com/karol-broda/snitch/internal/collector" "os"
"github.com/karol-broda/snitch/internal/theme" "strconv"
"strings"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/state"
"github.com/karol-broda/snitch/internal/theme"
) )
type model struct { type model struct {
@@ -51,19 +56,30 @@ type model struct {
// status message (temporary feedback) // status message (temporary feedback)
statusMessage string statusMessage string
statusExpiry time.Time statusExpiry time.Time
// export modal
showExportModal bool
exportFilename string
exportFormat string // "csv" or "tsv"
exportError string
// state persistence
rememberState bool
} }
type Options struct { type Options struct {
Theme string Theme string
Interval time.Duration Interval time.Duration
TCP bool TCP bool
UDP bool UDP bool
Listening bool Listening bool
Established bool Established bool
Other bool Other bool
FilterSet bool // true if user specified any filter flags FilterSet bool // true if user specified any filter flags
ResolveAddrs bool // when true, resolve IP addresses to hostnames ResolveAddrs bool // when true, resolve IP addresses to hostnames
ResolvePorts bool // when true, resolve port numbers to service names ResolvePorts bool // when true, resolve port numbers to service names
NoCache bool // when true, disable DNS caching
RememberState bool // when true, persist view options between sessions
} }
func New(opts Options) model { func New(opts Options) model {
@@ -78,8 +94,27 @@ func New(opts Options) model {
showListening := true showListening := true
showEstablished := true showEstablished := true
showOther := true showOther := true
sortField := collector.SortByLport
sortReverse := false
resolveAddrs := opts.ResolveAddrs
resolvePorts := opts.ResolvePorts
// if user specified filters, use those instead // load saved state if enabled and no CLI filter flags were specified
if opts.RememberState && !opts.FilterSet {
if saved := state.Load(); saved != nil {
showTCP = saved.ShowTCP
showUDP = saved.ShowUDP
showListening = saved.ShowListening
showEstablished = saved.ShowEstablished
showOther = saved.ShowOther
sortField = saved.SortField
sortReverse = saved.SortReverse
resolveAddrs = saved.ResolveAddrs
resolvePorts = saved.ResolvePorts
}
}
// if user specified filters, use those instead (CLI flags take precedence)
if opts.FilterSet { if opts.FilterSet {
showTCP = opts.TCP showTCP = opts.TCP
showUDP = opts.UDP showUDP = opts.UDP
@@ -107,13 +142,15 @@ func New(opts Options) model {
showListening: showListening, showListening: showListening,
showEstablished: showEstablished, showEstablished: showEstablished,
showOther: showOther, showOther: showOther,
sortField: collector.SortByLport, sortField: sortField,
resolveAddrs: opts.ResolveAddrs, sortReverse: sortReverse,
resolvePorts: opts.ResolvePorts, resolveAddrs: resolveAddrs,
resolvePorts: resolvePorts,
theme: theme.GetTheme(opts.Theme), theme: theme.GetTheme(opts.Theme),
interval: interval, interval: interval,
lastRefresh: time.Now(), lastRefresh: time.Now(),
watchedPIDs: make(map[int]bool), watchedPIDs: make(map[int]bool),
rememberState: opts.RememberState,
} }
} }
@@ -186,6 +223,11 @@ func (m model) View() string {
return m.overlayModal(main, m.renderKillModal()) return m.overlayModal(main, m.renderKillModal())
} }
// overlay export modal on top of main view
if m.showExportModal {
return m.overlayModal(main, m.renderExportModal())
}
return main return main
} }
@@ -261,12 +303,19 @@ func (m model) matchesFilters(c collector.Connection) bool {
} }
func (m model) matchesSearch(c collector.Connection) bool { func (m model) matchesSearch(c collector.Connection) bool {
lportStr := strconv.Itoa(c.Lport)
rportStr := strconv.Itoa(c.Rport)
pidStr := strconv.Itoa(c.PID)
return containsIgnoreCase(c.Process, m.searchQuery) || return containsIgnoreCase(c.Process, m.searchQuery) ||
containsIgnoreCase(c.Laddr, m.searchQuery) || containsIgnoreCase(c.Laddr, m.searchQuery) ||
containsIgnoreCase(c.Raddr, m.searchQuery) || containsIgnoreCase(c.Raddr, m.searchQuery) ||
containsIgnoreCase(c.User, m.searchQuery) || containsIgnoreCase(c.User, m.searchQuery) ||
containsIgnoreCase(c.Proto, m.searchQuery) || containsIgnoreCase(c.Proto, m.searchQuery) ||
containsIgnoreCase(c.State, m.searchQuery) containsIgnoreCase(c.State, m.searchQuery) ||
containsIgnoreCase(lportStr, m.searchQuery) ||
containsIgnoreCase(rportStr, m.searchQuery) ||
containsIgnoreCase(pidStr, m.searchQuery)
} }
func (m model) isWatched(pid int) bool { func (m model) isWatched(pid int) bool {
@@ -290,3 +339,84 @@ func (m *model) toggleWatch(pid int) {
func (m model) watchedCount() int { func (m model) watchedCount() int {
return len(m.watchedPIDs) return len(m.watchedPIDs)
} }
// currentState returns the current view options as a TUIState for persistence
func (m model) currentState() state.TUIState {
return state.TUIState{
ShowTCP: m.showTCP,
ShowUDP: m.showUDP,
ShowListening: m.showListening,
ShowEstablished: m.showEstablished,
ShowOther: m.showOther,
SortField: m.sortField,
SortReverse: m.sortReverse,
ResolveAddrs: m.resolveAddrs,
ResolvePorts: m.resolvePorts,
}
}
// saveState persists current view options in the background
func (m model) saveState() {
if m.rememberState {
state.SaveAsync(m.currentState())
}
}
// exportConnections writes visible connections to a file in csv or tsv format
func (m model) exportConnections() error {
visible := m.visibleConnections()
if len(visible) == 0 {
return fmt.Errorf("no connections to export")
}
file, err := os.Create(m.exportFilename)
if err != nil {
return err
}
defer func() { _ = file.Close() }()
// determine delimiter from format selection or filename
delimiter := ","
if m.exportFormat == "tsv" || strings.HasSuffix(strings.ToLower(m.exportFilename), ".tsv") {
delimiter = "\t"
}
header := []string{"PID", "PROCESS", "USER", "PROTO", "STATE", "LADDR", "LPORT", "RADDR", "RPORT"}
_, err = file.WriteString(strings.Join(header, delimiter) + "\n")
if err != nil {
return err
}
for _, c := range visible {
// escape fields that might contain delimiter
process := escapeField(c.Process, delimiter)
user := escapeField(c.User, delimiter)
row := []string{
strconv.Itoa(c.PID),
process,
user,
c.Proto,
c.State,
c.Laddr,
strconv.Itoa(c.Lport),
c.Raddr,
strconv.Itoa(c.Rport),
}
_, err = file.WriteString(strings.Join(row, delimiter) + "\n")
if err != nil {
return err
}
}
return nil
}
// escapeField quotes a field if it contains the delimiter or quotes
func escapeField(s, delimiter string) string {
if strings.Contains(s, delimiter) || strings.Contains(s, "\"") || strings.Contains(s, "\n") {
return "\"" + strings.ReplaceAll(s, "\"", "\"\"") + "\""
}
return s
}

View File

@@ -1,12 +1,15 @@
package tui package tui
import ( import (
"github.com/karol-broda/snitch/internal/collector" "os"
"path/filepath"
"strings"
"testing" "testing"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/x/exp/teatest" "github.com/charmbracelet/x/exp/teatest"
"github.com/karol-broda/snitch/internal/collector"
) )
func TestTUI_InitialState(t *testing.T) { func TestTUI_InitialState(t *testing.T) {
@@ -430,3 +433,346 @@ func TestTUI_FormatRemoteHelper(t *testing.T) {
} }
} }
func TestTUI_MatchesSearchPort(t *testing.T) {
m := New(Options{Theme: "dark"})
tests := []struct {
name string
searchQuery string
conn collector.Connection
expected bool
}{
{
name: "matches local port",
searchQuery: "3000",
conn: collector.Connection{Lport: 3000},
expected: true,
},
{
name: "matches remote port",
searchQuery: "443",
conn: collector.Connection{Rport: 443},
expected: true,
},
{
name: "matches pid",
searchQuery: "1234",
conn: collector.Connection{PID: 1234},
expected: true,
},
{
name: "partial port match",
searchQuery: "80",
conn: collector.Connection{Lport: 8080},
expected: true,
},
{
name: "no match",
searchQuery: "9999",
conn: collector.Connection{Lport: 80, Rport: 443, PID: 1234},
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m.searchQuery = tc.searchQuery
result := m.matchesSearch(tc.conn)
if result != tc.expected {
t.Errorf("matchesSearch() = %v, want %v", result, tc.expected)
}
})
}
}
func TestTUI_SortCycleIncludesRemote(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
// start at default (Lport)
if m.sortField != collector.SortByLport {
t.Fatalf("expected initial sort field to be lport, got %v", m.sortField)
}
// cycle through all fields and verify raddr and rport are included
foundRaddr := false
foundRport := false
seenFields := make(map[collector.SortField]bool)
for i := 0; i < 10; i++ {
m.cycleSort()
seenFields[m.sortField] = true
if m.sortField == collector.SortByRaddr {
foundRaddr = true
}
if m.sortField == collector.SortByRport {
foundRport = true
}
if foundRaddr && foundRport {
break
}
}
if !foundRaddr {
t.Error("expected sort cycle to include SortByRaddr")
}
if !foundRport {
t.Error("expected sort cycle to include SortByRport")
}
}
func TestTUI_ExportModal(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// initially export modal should not be shown
if m.showExportModal {
t.Fatal("expected showExportModal to be false initially")
}
// press 'x' to open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
if !m.showExportModal {
t.Error("expected showExportModal to be true after pressing 'x'")
}
// type filename
for _, c := range "test.csv" {
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{c}})
m = newModel.(model)
}
if m.exportFilename != "test.csv" {
t.Errorf("expected exportFilename to be 'test.csv', got '%s'", m.exportFilename)
}
// escape should close modal
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyEsc})
m = newModel.(model)
if m.showExportModal {
t.Error("expected showExportModal to be false after escape")
}
if m.exportFilename != "" {
t.Error("expected exportFilename to be cleared after escape")
}
}
func TestTUI_ExportModalDefaultFilename(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// add test data
m.connections = []collector.Connection{
{PID: 1234, Process: "nginx", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80},
}
// open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
// render export modal should show default filename hint
view := m.View()
if view == "" {
t.Error("expected non-empty view with export modal")
}
}
func TestTUI_ExportModalBackspace(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
// type filename
for _, c := range "test.csv" {
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{c}})
m = newModel.(model)
}
// backspace should remove last character
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyBackspace})
m = newModel.(model)
if m.exportFilename != "test.cs" {
t.Errorf("expected 'test.cs' after backspace, got '%s'", m.exportFilename)
}
}
func TestTUI_ExportConnectionsCSV(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.connections = []collector.Connection{
{PID: 1234, Process: "nginx", User: "www-data", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80, Raddr: "*", Rport: 0},
{PID: 5678, Process: "node", User: "node", Proto: "tcp", State: "ESTABLISHED", Laddr: "192.168.1.1", Lport: 3000, Raddr: "10.0.0.1", Rport: 443},
}
tmpDir := t.TempDir()
csvPath := filepath.Join(tmpDir, "test_export.csv")
m.exportFilename = csvPath
err := m.exportConnections()
if err != nil {
t.Fatalf("exportConnections() failed: %v", err)
}
content, err := os.ReadFile(csvPath)
if err != nil {
t.Fatalf("failed to read exported file: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
if len(lines) != 3 {
t.Errorf("expected 3 lines (header + 2 data), got %d", len(lines))
}
if !strings.Contains(lines[0], "PID") || !strings.Contains(lines[0], "PROCESS") {
t.Error("header line should contain PID and PROCESS")
}
if !strings.Contains(lines[1], "nginx") || !strings.Contains(lines[1], "1234") {
t.Error("first data line should contain nginx and 1234")
}
if !strings.Contains(lines[2], "node") || !strings.Contains(lines[2], "5678") {
t.Error("second data line should contain node and 5678")
}
}
func TestTUI_ExportConnectionsTSV(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.connections = []collector.Connection{
{PID: 1234, Process: "nginx", User: "www-data", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80, Raddr: "*", Rport: 0},
}
tmpDir := t.TempDir()
tsvPath := filepath.Join(tmpDir, "test_export.tsv")
m.exportFilename = tsvPath
err := m.exportConnections()
if err != nil {
t.Fatalf("exportConnections() failed: %v", err)
}
content, err := os.ReadFile(tsvPath)
if err != nil {
t.Fatalf("failed to read exported file: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
// TSV should use tabs
if !strings.Contains(lines[0], "\t") {
t.Error("TSV file should use tabs as delimiters")
}
// CSV delimiter should not be present between fields
fields := strings.Split(lines[1], "\t")
if len(fields) < 9 {
t.Errorf("expected at least 9 tab-separated fields, got %d", len(fields))
}
}
func TestTUI_ExportWithFilters(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.showTCP = true
m.showUDP = false
m.connections = []collector.Connection{
{PID: 1, Process: "tcp_proc", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80},
{PID: 2, Process: "udp_proc", Proto: "udp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 53},
}
tmpDir := t.TempDir()
csvPath := filepath.Join(tmpDir, "filtered_export.csv")
m.exportFilename = csvPath
err := m.exportConnections()
if err != nil {
t.Fatalf("exportConnections() failed: %v", err)
}
content, err := os.ReadFile(csvPath)
if err != nil {
t.Fatalf("failed to read exported file: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
// should only have header + 1 TCP connection (UDP filtered out)
if len(lines) != 2 {
t.Errorf("expected 2 lines (header + 1 TCP), got %d", len(lines))
}
if strings.Contains(string(content), "udp_proc") {
t.Error("UDP connection should not be exported when UDP filter is off")
}
}
func TestTUI_ExportFormatToggle(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
// default format should be csv
if m.exportFormat != "csv" {
t.Errorf("expected default format 'csv', got '%s'", m.exportFormat)
}
// tab should toggle to tsv
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = newModel.(model)
if m.exportFormat != "tsv" {
t.Errorf("expected format 'tsv' after tab, got '%s'", m.exportFormat)
}
// tab again should toggle back to csv
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = newModel.(model)
if m.exportFormat != "csv" {
t.Errorf("expected format 'csv' after second tab, got '%s'", m.exportFormat)
}
}
func TestTUI_ExportModalRenderWithStats(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
m.connections = []collector.Connection{
{PID: 1, Process: "nginx", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80},
{PID: 2, Process: "postgres", Proto: "tcp", State: "LISTEN", Laddr: "127.0.0.1", Lport: 5432},
{PID: 3, Process: "node", Proto: "tcp", State: "ESTABLISHED", Laddr: "192.168.1.1", Lport: 3000},
}
m.showExportModal = true
m.exportFormat = "csv"
view := m.View()
// modal should contain summary info
if !strings.Contains(view, "3") {
t.Error("modal should show connection count")
}
// modal should show format options
if !strings.Contains(view, "CSV") || !strings.Contains(view, "TSV") {
t.Error("modal should show format options")
}
}

View File

@@ -33,6 +33,8 @@ const (
BoxCross = string('\u253C') // light vertical and horizontal BoxCross = string('\u253C') // light vertical and horizontal
// misc // misc
SymbolDash = string('\u2013') // en dash SymbolDash = string('\u2013') // en dash
SymbolExport = string('\u21E5') // rightwards arrow to bar
SymbolPrompt = string('\u276F') // heavy right-pointing angle quotation mark ornament
) )

View File

@@ -203,7 +203,7 @@ func (m model) renderStatusLine() string {
return " " + m.theme.Styles.Warning.Render(m.statusMessage) return " " + m.theme.Styles.Warning.Render(m.statusMessage)
} }
left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state n/N dns w watch K kill s sort / search ? help q quit") left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state n/N dns w watch K kill s sort / search x export ? help q quit")
// show watched count if any // show watched count if any
if m.watchedCount() > 0 { if m.watchedCount() > 0 {
@@ -271,6 +271,7 @@ func (m model) renderHelp() string {
other other
───── ─────
/ search / search
x export to csv/tsv (enter filename)
r refresh now r refresh now
q quit q quit
@@ -301,6 +302,8 @@ func (m model) renderDetail() string {
value string value string
}{ }{
{"process", c.Process}, {"process", c.Process},
{"cmdline", c.Cmdline},
{"cwd", c.Cwd},
{"pid", fmt.Sprintf("%d", c.PID)}, {"pid", fmt.Sprintf("%d", c.PID)},
{"user", c.User}, {"user", c.User},
{"protocol", c.Proto}, {"protocol", c.Proto},
@@ -368,6 +371,119 @@ func (m model) renderKillModal() string {
return strings.Join(lines, "\n") return strings.Join(lines, "\n")
} }
func (m model) renderExportModal() string {
visible := m.visibleConnections()
// count protocols and states for preview
tcpCount, udpCount := 0, 0
listenCount, estabCount, otherCount := 0, 0, 0
for _, c := range visible {
if c.Proto == "tcp" || c.Proto == "tcp6" {
tcpCount++
} else {
udpCount++
}
switch c.State {
case "LISTEN":
listenCount++
case "ESTABLISHED":
estabCount++
default:
otherCount++
}
}
var lines []string
// header
lines = append(lines, "")
headerText := " " + SymbolExport + " EXPORT CONNECTIONS "
lines = append(lines, m.theme.Styles.Header.Render(headerText))
lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 36)))
lines = append(lines, "")
// stats preview section
lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" summary"))
lines = append(lines, fmt.Sprintf(" total: %s",
m.theme.Styles.Success.Render(fmt.Sprintf("%d connections", len(visible)))))
protoSummary := fmt.Sprintf(" proto: %s tcp %s udp",
m.theme.Styles.GetProtoStyle("tcp").Render(fmt.Sprintf("%d", tcpCount)),
m.theme.Styles.GetProtoStyle("udp").Render(fmt.Sprintf("%d", udpCount)))
lines = append(lines, protoSummary)
stateSummary := fmt.Sprintf(" state: %s listen %s estab %s other",
m.theme.Styles.GetStateStyle("LISTEN").Render(fmt.Sprintf("%d", listenCount)),
m.theme.Styles.GetStateStyle("ESTABLISHED").Render(fmt.Sprintf("%d", estabCount)),
m.theme.Styles.Normal.Render(fmt.Sprintf("%d", otherCount)))
lines = append(lines, stateSummary)
lines = append(lines, "")
// format selection
lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" format"))
csvStyle := m.theme.Styles.Normal
tsvStyle := m.theme.Styles.Normal
csvIndicator := " "
tsvIndicator := " "
if m.exportFormat == "tsv" {
tsvStyle = m.theme.Styles.Success
tsvIndicator = m.theme.Styles.Success.Render(SymbolSelected + " ")
} else {
csvStyle = m.theme.Styles.Success
csvIndicator = m.theme.Styles.Success.Render(SymbolSelected + " ")
}
formatLine := fmt.Sprintf(" %s%s %s%s",
csvIndicator, csvStyle.Render("CSV (comma)"),
tsvIndicator, tsvStyle.Render("TSV (tab)"))
lines = append(lines, formatLine)
lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 8)+" press "+m.theme.Styles.Warning.Render("tab")+" to toggle"))
lines = append(lines, "")
// filename input
lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" filename"))
ext := ".csv"
if m.exportFormat == "tsv" {
ext = ".tsv"
}
filenameDisplay := m.exportFilename
if filenameDisplay == "" {
filenameDisplay = "connections"
}
inputBox := fmt.Sprintf(" %s %s%s",
m.theme.Styles.Success.Render(SymbolPrompt),
m.theme.Styles.Warning.Render(filenameDisplay),
m.theme.Styles.Success.Render(ext+"▌"))
lines = append(lines, inputBox)
lines = append(lines, "")
// error display
if m.exportError != "" {
lines = append(lines, m.theme.Styles.Error.Render(fmt.Sprintf(" %s %s", SymbolWarning, m.exportError)))
lines = append(lines, "")
}
// preview of fields
lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 36)))
fieldsPreview := " fields: PID, PROCESS, USER, PROTO, STATE, LADDR, LPORT, RADDR, RPORT"
lines = append(lines, m.theme.Styles.Normal.Render(truncate(fieldsPreview, 40)))
lines = append(lines, "")
// action buttons
lines = append(lines, fmt.Sprintf(" %s export %s toggle format %s cancel",
m.theme.Styles.Success.Render("[enter]"),
m.theme.Styles.Warning.Render("[tab]"),
m.theme.Styles.Error.Render("[esc]")))
lines = append(lines, "")
return strings.Join(lines, "\n")
}
func (m model) overlayModal(background, modal string) string { func (m model) overlayModal(background, modal string) string {
bgLines := strings.Split(background, "\n") bgLines := strings.Split(background, "\n")
modalLines := strings.Split(modal, "\n") modalLines := strings.Split(modal, "\n")

177
nix/hm-module.nix Normal file
View File

@@ -0,0 +1,177 @@
{
config,
lib,
pkgs,
...
}:
let
cfg = config.programs.snitch;
themes = [
"ansi"
"catppuccin-mocha"
"catppuccin-macchiato"
"catppuccin-frappe"
"catppuccin-latte"
"gruvbox-dark"
"gruvbox-light"
"dracula"
"nord"
"tokyo-night"
"tokyo-night-storm"
"tokyo-night-light"
"solarized-dark"
"solarized-light"
"one-dark"
"mono"
"auto"
];
defaultFields = [
"pid"
"process"
"user"
"proto"
"state"
"laddr"
"lport"
"raddr"
"rport"
];
tomlFormat = pkgs.formats.toml { };
settingsType = lib.types.submodule {
freeformType = tomlFormat.type;
options = {
defaults = lib.mkOption {
type = lib.types.submodule {
freeformType = tomlFormat.type;
options = {
interval = lib.mkOption {
type = lib.types.str;
default = "1s";
example = "2s";
description = "Default refresh interval for watch/stats/trace commands.";
};
numeric = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Disable name/service resolution by default.";
};
fields = lib.mkOption {
type = lib.types.listOf lib.types.str;
default = defaultFields;
example = [ "pid" "process" "proto" "state" "laddr" "lport" ];
description = "Default fields to display.";
};
theme = lib.mkOption {
type = lib.types.enum themes;
default = "ansi";
description = ''
Color theme for the TUI. "ansi" inherits terminal colors.
'';
};
units = lib.mkOption {
type = lib.types.enum [ "auto" "si" "iec" ];
default = "auto";
description = "Default units for byte display.";
};
color = lib.mkOption {
type = lib.types.enum [ "auto" "always" "never" ];
default = "auto";
description = "Default color mode.";
};
resolve = lib.mkOption {
type = lib.types.bool;
default = true;
description = "Enable name resolution by default.";
};
dns_cache = lib.mkOption {
type = lib.types.bool;
default = true;
description = "Enable DNS caching.";
};
ipv4 = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Filter to IPv4 only by default.";
};
ipv6 = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Filter to IPv6 only by default.";
};
no_headers = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Omit headers in output by default.";
};
output_format = lib.mkOption {
type = lib.types.enum [ "table" "json" "csv" ];
default = "table";
description = "Default output format.";
};
sort_by = lib.mkOption {
type = lib.types.str;
default = "";
example = "pid";
description = "Default sort field.";
};
};
};
default = { };
description = "Default settings for snitch commands.";
};
};
};
in
{
options.programs.snitch = {
enable = lib.mkEnableOption "snitch, a friendlier ss/netstat for humans";
package = lib.mkPackageOption pkgs "snitch" { };
settings = lib.mkOption {
type = settingsType;
default = { };
example = lib.literalExpression ''
{
defaults = {
theme = "catppuccin-mocha";
interval = "2s";
resolve = true;
};
}
'';
description = ''
Configuration written to {file}`$XDG_CONFIG_HOME/snitch/snitch.toml`.
See <https://github.com/karol-broda/snitch> for available options.
'';
};
};
config = lib.mkIf cfg.enable {
home.packages = [ cfg.package ];
xdg.configFile."snitch/snitch.toml" = lib.mkIf (cfg.settings != { }) {
source = tomlFormat.generate "snitch.toml" cfg.settings;
};
};
}

View File

@@ -0,0 +1,429 @@
# home manager module tests
#
# run with: nix build .#checks.x86_64-linux.hm-module
#
# tests cover:
# - module evaluation with various configurations
# - type validation for all options
# - generated TOML content verification
# - edge cases (disabled, empty settings, full settings)
{ pkgs, lib, hmModule }:
let
# minimal home-manager stub for standalone module testing
hmLib = {
hm.types.dagOf = lib.types.attrsOf;
dag.entryAnywhere = x: x;
};
# evaluate the hm module with a given config
evalModule = testConfig:
lib.evalModules {
modules = [
hmModule
# stub home-manager's expected structure
{
options = {
home.packages = lib.mkOption {
type = lib.types.listOf lib.types.package;
default = [ ];
};
xdg.configFile = lib.mkOption {
type = lib.types.attrsOf (lib.types.submodule {
options = {
source = lib.mkOption { type = lib.types.path; };
text = lib.mkOption { type = lib.types.str; default = ""; };
};
});
default = { };
};
};
}
testConfig
];
specialArgs = { inherit pkgs lib; };
};
# read generated TOML file content
readGeneratedToml = evalResult:
let
configFile = evalResult.config.xdg.configFile."snitch/snitch.toml" or null;
in
if configFile != null && configFile ? source
then builtins.readFile configFile.source
else null;
# test cases
tests = {
# test 1: module evaluates when disabled
moduleDisabled = {
name = "module-disabled";
config = {
programs.snitch.enable = false;
};
assertions = evalResult: [
{
assertion = evalResult.config.home.packages == [ ];
message = "packages should be empty when disabled";
}
{
assertion = !(evalResult.config.xdg.configFile ? "snitch/snitch.toml");
message = "config file should not exist when disabled";
}
];
};
# test 2: module evaluates with enable only (defaults)
moduleEnabledDefaults = {
name = "module-enabled-defaults";
config = {
programs.snitch.enable = true;
};
assertions = evalResult: [
{
assertion = builtins.length evalResult.config.home.packages == 1;
message = "package should be installed when enabled";
}
];
};
# test 3: all theme values are valid
themeValidation = {
name = "theme-validation";
config = {
programs.snitch = {
enable = true;
settings.defaults.theme = "catppuccin-mocha";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = toml != null;
message = "TOML config should be generated";
}
{
assertion = lib.hasInfix "catppuccin-mocha" toml;
message = "theme should be set in TOML";
}
];
};
# test 4: full configuration with all options
fullConfiguration = {
name = "full-configuration";
config = {
programs.snitch = {
enable = true;
settings.defaults = {
interval = "2s";
numeric = true;
fields = [ "pid" "process" "proto" ];
theme = "nord";
units = "si";
color = "always";
resolve = false;
dns_cache = false;
ipv4 = true;
ipv6 = false;
no_headers = true;
output_format = "json";
sort_by = "pid";
};
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = toml != null;
message = "TOML config should be generated";
}
{
assertion = lib.hasInfix "interval = \"2s\"" toml;
message = "interval should be 2s";
}
{
assertion = lib.hasInfix "numeric = true" toml;
message = "numeric should be true";
}
{
assertion = lib.hasInfix "theme = \"nord\"" toml;
message = "theme should be nord";
}
{
assertion = lib.hasInfix "units = \"si\"" toml;
message = "units should be si";
}
{
assertion = lib.hasInfix "color = \"always\"" toml;
message = "color should be always";
}
{
assertion = lib.hasInfix "resolve = false" toml;
message = "resolve should be false";
}
{
assertion = lib.hasInfix "output_format = \"json\"" toml;
message = "output_format should be json";
}
{
assertion = lib.hasInfix "sort_by = \"pid\"" toml;
message = "sort_by should be pid";
}
];
};
# test 5: output format enum validation
outputFormatCsv = {
name = "output-format-csv";
config = {
programs.snitch = {
enable = true;
settings.defaults.output_format = "csv";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "output_format = \"csv\"" toml;
message = "output_format should accept csv";
}
];
};
# test 6: units enum validation
unitsIec = {
name = "units-iec";
config = {
programs.snitch = {
enable = true;
settings.defaults.units = "iec";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "units = \"iec\"" toml;
message = "units should accept iec";
}
];
};
# test 7: color never value
colorNever = {
name = "color-never";
config = {
programs.snitch = {
enable = true;
settings.defaults.color = "never";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "color = \"never\"" toml;
message = "color should accept never";
}
];
};
# test 8: freeform type allows custom keys
freeformCustomKeys = {
name = "freeform-custom-keys";
config = {
programs.snitch = {
enable = true;
settings = {
defaults.theme = "dracula";
custom_section = {
custom_key = "custom_value";
};
};
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "custom_key" toml;
message = "freeform type should allow custom keys";
}
];
};
# test 9: all themes evaluate correctly
allThemes =
let
themes = [
"ansi"
"catppuccin-mocha"
"catppuccin-macchiato"
"catppuccin-frappe"
"catppuccin-latte"
"gruvbox-dark"
"gruvbox-light"
"dracula"
"nord"
"tokyo-night"
"tokyo-night-storm"
"tokyo-night-light"
"solarized-dark"
"solarized-light"
"one-dark"
"mono"
"auto"
];
in
{
name = "all-themes";
# use the last theme as the test config
config = {
programs.snitch = {
enable = true;
settings.defaults.theme = "auto";
};
};
assertions = evalResult:
let
# verify all themes can be set by evaluating them
themeResults = map
(theme:
let
result = evalModule {
programs.snitch = {
enable = true;
settings.defaults.theme = theme;
};
};
toml = readGeneratedToml result;
in
{
inherit theme;
success = toml != null && lib.hasInfix theme toml;
}
)
themes;
allSucceeded = lib.all (r: r.success) themeResults;
in
[
{
assertion = allSucceeded;
message = "all themes should evaluate correctly: ${
lib.concatMapStringsSep ", "
(r: "${r.theme}=${if r.success then "ok" else "fail"}")
themeResults
}";
}
];
};
# test 10: fields list serialization
fieldsListSerialization = {
name = "fields-list-serialization";
config = {
programs.snitch = {
enable = true;
settings.defaults.fields = [ "pid" "process" "proto" "state" ];
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "pid" toml && lib.hasInfix "process" toml;
message = "fields list should be serialized correctly";
}
];
};
};
# run all tests and collect results
runTests =
let
testResults = lib.mapAttrsToList
(name: test:
let
evalResult = evalModule test.config;
assertions = test.assertions evalResult;
failures = lib.filter (a: !a.assertion) assertions;
in
{
inherit name;
testName = test.name;
passed = failures == [ ];
failures = map (f: f.message) failures;
}
)
tests;
allPassed = lib.all (r: r.passed) testResults;
failedTests = lib.filter (r: !r.passed) testResults;
summary = ''
========================================
home manager module test results
========================================
total tests: ${toString (builtins.length testResults)}
passed: ${toString (builtins.length (lib.filter (r: r.passed) testResults))}
failed: ${toString (builtins.length failedTests)}
========================================
${lib.concatMapStringsSep "\n" (r:
if r.passed
then "[yes] ${r.testName}"
else "[no] ${r.testName}\n ${lib.concatStringsSep "\n " r.failures}"
) testResults}
========================================
'';
in
{
inherit testResults allPassed failedTests summary;
};
results = runTests;
in
pkgs.runCommand "hm-module-test"
{
passthru = {
inherit results;
# expose for debugging
inherit evalModule tests;
};
}
(
if results.allPassed
then ''
echo "${results.summary}"
echo "all tests passed"
touch $out
''
else ''
echo "${results.summary}"
echo ""
echo "failed tests:"
${lib.concatMapStringsSep "\n" (t: ''
echo " - ${t.testName}: ${lib.concatStringsSep ", " t.failures}"
'') results.failedTests}
exit 1
''
)