From 49948de0ed421661b23be3369ec2b1acb0d92c38 Mon Sep 17 00:00:00 2001 From: Karol Broda <122811026+karol-broda@users.noreply.github.com> Date: Wed, 24 Dec 2025 11:12:39 +0100 Subject: [PATCH] fix: dns resolution taking long and add caching options (#8) --- README.md | 28 +- cmd/ls.go | 7 +- cmd/root.go | 5 +- cmd/runtime.go | 47 +- cmd/runtime_test.go | 525 +++++++++++++++++++++++ cmd/top.go | 21 +- cmd/trace.go | 29 +- internal/config/config.go | 11 +- internal/resolver/resolver.go | 139 ++++-- internal/resolver/resolver_bench_test.go | 159 +++++++ internal/resolver/resolver_test.go | 387 +++++++++++++++++ internal/tui/messages.go | 10 + internal/tui/model.go | 1 + 13 files changed, 1302 insertions(+), 67 deletions(-) create mode 100644 cmd/runtime_test.go create mode 100644 internal/resolver/resolver_bench_test.go create mode 100644 internal/resolver/resolver_test.go diff --git a/README.md b/README.md index 9db0197..f974ea8 100644 --- a/README.md +++ b/README.md @@ -167,9 +167,20 @@ shortcut flags work on all commands: -e, --established established connections -4, --ipv4 ipv4 only -6, --ipv6 ipv6 only --n, --numeric no dns resolution ``` +## resolution + +dns and service name resolution options: + +``` +--resolve-addrs resolve ip addresses to hostnames (default: true) +--resolve-ports resolve port numbers to service names +--no-cache disable dns caching (force fresh lookups) +``` + +dns lookups are performed in parallel and cached for performance. use `--no-cache` to bypass the cache for debugging or when addresses change frequently. + for more specific filtering, use `key=value` syntax with `ls`: ```bash @@ -208,8 +219,19 @@ optional config file at `~/.config/snitch/snitch.toml`: ```toml [defaults] -numeric = false -theme = "auto" +numeric = false # disable name resolution +dns_cache = true # cache dns lookups (set to false to disable) +theme = "auto" # color theme: auto, dark, light, mono +``` + +### environment variables + +```bash +SNITCH_THEME=dark # set default theme +SNITCH_RESOLVE=0 # disable dns resolution +SNITCH_DNS_CACHE=0 # disable dns caching +SNITCH_NO_COLOR=1 # disable color output +SNITCH_CONFIG=/path/to # custom config file path ``` ## requirements diff --git a/cmd/ls.go b/cmd/ls.go index 93cd7ad..e40c98d 100644 --- a/cmd/ls.go +++ b/cmd/ls.go @@ -30,8 +30,6 @@ var ( sortBy string fields string colorMode string - resolveAddrs bool - resolvePorts bool plainOutput bool ) @@ -400,10 +398,9 @@ func init() { lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)") lsCmd.Flags().StringVarP(&fields, "fields", "f", strings.Join(cfg.Defaults.Fields, ","), "Comma-separated list of fields to show") lsCmd.Flags().StringVar(&colorMode, "color", cfg.Defaults.Color, "Color mode (auto, always, never)") - lsCmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames") - lsCmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names") lsCmd.Flags().BoolVarP(&plainOutput, "plain", "p", false, "Plain output (parsable, no styling)") - // shared filter flags + // shared flags addFilterFlags(lsCmd) + addResolutionFlags(lsCmd) } \ No newline at end of file diff --git a/cmd/root.go b/cmd/root.go index 90938dc..72d631d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -45,9 +45,8 @@ func init() { cfg := config.Get() rootCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')") rootCmd.Flags().DurationVarP(&topInterval, "interval", "i", 0, "Refresh interval (default 1s)") - rootCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames") - rootCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names") - // shared filter flags for root command + // shared flags for root command addFilterFlags(rootCmd) + addResolutionFlags(rootCmd) } \ No newline at end of file diff --git a/cmd/runtime.go b/cmd/runtime.go index cf757e5..597d5d1 100644 --- a/cmd/runtime.go +++ b/cmd/runtime.go @@ -4,6 +4,8 @@ import ( "fmt" "github.com/karol-broda/snitch/internal/collector" "github.com/karol-broda/snitch/internal/color" + "github.com/karol-broda/snitch/internal/config" + "github.com/karol-broda/snitch/internal/resolver" "strconv" "strings" @@ -11,7 +13,7 @@ import ( ) // Runtime holds the shared state for all commands. -// it handles common filter logic, fetching, and filtering connections. +// it handles common filter logic, fetching, filtering, and resolution. type Runtime struct { // filter options built from flags and args Filters collector.FilterOptions @@ -23,6 +25,7 @@ type Runtime struct { ColorMode string ResolveAddrs bool ResolvePorts bool + NoCache bool } // shared filter flags - used by all commands @@ -35,6 +38,13 @@ var ( filterIPv6 bool ) +// shared resolution flags - used by all commands +var ( + resolveAddrs bool + resolvePorts bool + noCache bool +) + // BuildFilters constructs FilterOptions from command args and shortcut flags. func BuildFilters(args []string) (collector.FilterOptions, error) { filters, err := ParseFilterArgs(args) @@ -77,6 +87,12 @@ func FetchConnections(filters collector.FilterOptions) ([]collector.Connection, func NewRuntime(args []string, colorMode string) (*Runtime, error) { color.Init(colorMode) + cfg := config.Get() + + // configure resolver with cache setting (flag overrides config) + effectiveNoCache := noCache || !cfg.Defaults.DNSCache + resolver.SetNoCache(effectiveNoCache) + filters, err := BuildFilters(args) if err != nil { return nil, fmt.Errorf("failed to parse filters: %w", err) @@ -87,13 +103,30 @@ func NewRuntime(args []string, colorMode string) (*Runtime, error) { return nil, fmt.Errorf("failed to fetch connections: %w", err) } - return &Runtime{ + rt := &Runtime{ Filters: filters, Connections: connections, ColorMode: colorMode, ResolveAddrs: resolveAddrs, ResolvePorts: resolvePorts, - }, nil + NoCache: effectiveNoCache, + } + + // pre-warm dns cache by resolving all addresses in parallel + if resolveAddrs { + rt.PreWarmDNS() + } + + return rt, nil +} + +// PreWarmDNS resolves all connection addresses in parallel to warm the cache. +func (r *Runtime) PreWarmDNS() { + addrs := make([]string, 0, len(r.Connections)*2) + for _, c := range r.Connections { + addrs = append(addrs, c.Laddr, c.Raddr) + } + resolver.ResolveAddrsParallel(addrs) } // SortConnections sorts the runtime's connections in place. @@ -201,3 +234,11 @@ func addFilterFlags(cmd *cobra.Command) { cmd.Flags().BoolVarP(&filterIPv6, "ipv6", "6", false, "Only show IPv6 connections") } +// addResolutionFlags adds the common resolution flags to a command. +func addResolutionFlags(cmd *cobra.Command) { + cfg := config.Get() + cmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames") + cmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names") + cmd.Flags().BoolVar(&noCache, "no-cache", !cfg.Defaults.DNSCache, "Disable DNS caching (force fresh lookups)") +} + diff --git a/cmd/runtime_test.go b/cmd/runtime_test.go new file mode 100644 index 0000000..a738061 --- /dev/null +++ b/cmd/runtime_test.go @@ -0,0 +1,525 @@ +package cmd + +import ( + "testing" + + "github.com/karol-broda/snitch/internal/collector" +) + +func TestParseFilterArgs_Empty(t *testing.T) { + filters, err := ParseFilterArgs([]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proto != "" { + t.Errorf("expected empty proto, got %q", filters.Proto) + } +} + +func TestParseFilterArgs_Proto(t *testing.T) { + filters, err := ParseFilterArgs([]string{"proto=tcp"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proto != "tcp" { + t.Errorf("expected proto 'tcp', got %q", filters.Proto) + } +} + +func TestParseFilterArgs_State(t *testing.T) { + filters, err := ParseFilterArgs([]string{"state=LISTEN"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.State != "LISTEN" { + t.Errorf("expected state 'LISTEN', got %q", filters.State) + } +} + +func TestParseFilterArgs_PID(t *testing.T) { + filters, err := ParseFilterArgs([]string{"pid=1234"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Pid != 1234 { + t.Errorf("expected pid 1234, got %d", filters.Pid) + } +} + +func TestParseFilterArgs_InvalidPID(t *testing.T) { + _, err := ParseFilterArgs([]string{"pid=notanumber"}) + if err == nil { + t.Error("expected error for invalid pid") + } +} + +func TestParseFilterArgs_Proc(t *testing.T) { + filters, err := ParseFilterArgs([]string{"proc=nginx"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proc != "nginx" { + t.Errorf("expected proc 'nginx', got %q", filters.Proc) + } +} + +func TestParseFilterArgs_Lport(t *testing.T) { + filters, err := ParseFilterArgs([]string{"lport=80"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Lport != 80 { + t.Errorf("expected lport 80, got %d", filters.Lport) + } +} + +func TestParseFilterArgs_InvalidLport(t *testing.T) { + _, err := ParseFilterArgs([]string{"lport=notaport"}) + if err == nil { + t.Error("expected error for invalid lport") + } +} + +func TestParseFilterArgs_Rport(t *testing.T) { + filters, err := ParseFilterArgs([]string{"rport=443"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Rport != 443 { + t.Errorf("expected rport 443, got %d", filters.Rport) + } +} + +func TestParseFilterArgs_InvalidRport(t *testing.T) { + _, err := ParseFilterArgs([]string{"rport=invalid"}) + if err == nil { + t.Error("expected error for invalid rport") + } +} + +func TestParseFilterArgs_UserByName(t *testing.T) { + filters, err := ParseFilterArgs([]string{"user=root"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.User != "root" { + t.Errorf("expected user 'root', got %q", filters.User) + } +} + +func TestParseFilterArgs_UserByUID(t *testing.T) { + filters, err := ParseFilterArgs([]string{"user=1000"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.UID != 1000 { + t.Errorf("expected uid 1000, got %d", filters.UID) + } +} + +func TestParseFilterArgs_Laddr(t *testing.T) { + filters, err := ParseFilterArgs([]string{"laddr=127.0.0.1"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Laddr != "127.0.0.1" { + t.Errorf("expected laddr '127.0.0.1', got %q", filters.Laddr) + } +} + +func TestParseFilterArgs_Raddr(t *testing.T) { + filters, err := ParseFilterArgs([]string{"raddr=8.8.8.8"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Raddr != "8.8.8.8" { + t.Errorf("expected raddr '8.8.8.8', got %q", filters.Raddr) + } +} + +func TestParseFilterArgs_Contains(t *testing.T) { + filters, err := ParseFilterArgs([]string{"contains=google"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Contains != "google" { + t.Errorf("expected contains 'google', got %q", filters.Contains) + } +} + +func TestParseFilterArgs_Interface(t *testing.T) { + filters, err := ParseFilterArgs([]string{"if=eth0"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Interface != "eth0" { + t.Errorf("expected interface 'eth0', got %q", filters.Interface) + } + + // test alternative syntax + filters2, err := ParseFilterArgs([]string{"interface=lo"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters2.Interface != "lo" { + t.Errorf("expected interface 'lo', got %q", filters2.Interface) + } +} + +func TestParseFilterArgs_Mark(t *testing.T) { + filters, err := ParseFilterArgs([]string{"mark=0x1234"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Mark != "0x1234" { + t.Errorf("expected mark '0x1234', got %q", filters.Mark) + } +} + +func TestParseFilterArgs_Namespace(t *testing.T) { + filters, err := ParseFilterArgs([]string{"namespace=default"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Namespace != "default" { + t.Errorf("expected namespace 'default', got %q", filters.Namespace) + } +} + +func TestParseFilterArgs_Inode(t *testing.T) { + filters, err := ParseFilterArgs([]string{"inode=123456"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Inode != 123456 { + t.Errorf("expected inode 123456, got %d", filters.Inode) + } +} + +func TestParseFilterArgs_InvalidInode(t *testing.T) { + _, err := ParseFilterArgs([]string{"inode=notanumber"}) + if err == nil { + t.Error("expected error for invalid inode") + } +} + +func TestParseFilterArgs_Multiple(t *testing.T) { + filters, err := ParseFilterArgs([]string{"proto=tcp", "state=LISTEN", "lport=80"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proto != "tcp" { + t.Errorf("expected proto 'tcp', got %q", filters.Proto) + } + if filters.State != "LISTEN" { + t.Errorf("expected state 'LISTEN', got %q", filters.State) + } + if filters.Lport != 80 { + t.Errorf("expected lport 80, got %d", filters.Lport) + } +} + +func TestParseFilterArgs_InvalidFormat(t *testing.T) { + _, err := ParseFilterArgs([]string{"invalidformat"}) + if err == nil { + t.Error("expected error for invalid format") + } +} + +func TestParseFilterArgs_UnknownKey(t *testing.T) { + _, err := ParseFilterArgs([]string{"unknownkey=value"}) + if err == nil { + t.Error("expected error for unknown key") + } +} + +func TestParseFilterArgs_CaseInsensitiveKeys(t *testing.T) { + filters, err := ParseFilterArgs([]string{"PROTO=tcp", "State=LISTEN"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proto != "tcp" { + t.Errorf("expected proto 'tcp', got %q", filters.Proto) + } + if filters.State != "LISTEN" { + t.Errorf("expected state 'LISTEN', got %q", filters.State) + } +} + +func TestBuildFilters_TCPOnly(t *testing.T) { + // save and restore global flags + oldTCP, oldUDP := filterTCP, filterUDP + defer func() { + filterTCP, filterUDP = oldTCP, oldUDP + }() + + filterTCP = true + filterUDP = false + + filters, err := BuildFilters([]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proto != "tcp" { + t.Errorf("expected proto 'tcp', got %q", filters.Proto) + } +} + +func TestBuildFilters_UDPOnly(t *testing.T) { + oldTCP, oldUDP := filterTCP, filterUDP + defer func() { + filterTCP, filterUDP = oldTCP, oldUDP + }() + + filterTCP = false + filterUDP = true + + filters, err := BuildFilters([]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proto != "udp" { + t.Errorf("expected proto 'udp', got %q", filters.Proto) + } +} + +func TestBuildFilters_ListenOnly(t *testing.T) { + oldListen, oldEstab := filterListen, filterEstab + defer func() { + filterListen, filterEstab = oldListen, oldEstab + }() + + filterListen = true + filterEstab = false + + filters, err := BuildFilters([]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.State != "LISTEN" { + t.Errorf("expected state 'LISTEN', got %q", filters.State) + } +} + +func TestBuildFilters_EstablishedOnly(t *testing.T) { + oldListen, oldEstab := filterListen, filterEstab + defer func() { + filterListen, filterEstab = oldListen, oldEstab + }() + + filterListen = false + filterEstab = true + + filters, err := BuildFilters([]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.State != "ESTABLISHED" { + t.Errorf("expected state 'ESTABLISHED', got %q", filters.State) + } +} + +func TestBuildFilters_IPv4Flag(t *testing.T) { + oldIPv4 := filterIPv4 + defer func() { + filterIPv4 = oldIPv4 + }() + + filterIPv4 = true + + filters, err := BuildFilters([]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !filters.IPv4 { + t.Error("expected IPv4 to be true") + } +} + +func TestBuildFilters_IPv6Flag(t *testing.T) { + oldIPv6 := filterIPv6 + defer func() { + filterIPv6 = oldIPv6 + }() + + filterIPv6 = true + + filters, err := BuildFilters([]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !filters.IPv6 { + t.Error("expected IPv6 to be true") + } +} + +func TestBuildFilters_CombinedArgsAndFlags(t *testing.T) { + oldTCP := filterTCP + defer func() { + filterTCP = oldTCP + }() + + filterTCP = true + + filters, err := BuildFilters([]string{"lport=80"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filters.Proto != "tcp" { + t.Errorf("expected proto 'tcp', got %q", filters.Proto) + } + if filters.Lport != 80 { + t.Errorf("expected lport 80, got %d", filters.Lport) + } +} + +func TestRuntime_PreWarmDNS(t *testing.T) { + rt := &Runtime{ + Connections: []collector.Connection{ + {Laddr: "127.0.0.1", Raddr: "192.168.1.1"}, + {Laddr: "127.0.0.1", Raddr: "10.0.0.1"}, + }, + } + + // should not panic + rt.PreWarmDNS() +} + +func TestRuntime_PreWarmDNS_Empty(t *testing.T) { + rt := &Runtime{ + Connections: []collector.Connection{}, + } + + // should not panic with empty connections + rt.PreWarmDNS() +} + +func TestRuntime_SortConnections(t *testing.T) { + rt := &Runtime{ + Connections: []collector.Connection{ + {Lport: 443}, + {Lport: 80}, + {Lport: 8080}, + }, + } + + rt.SortConnections(collector.SortOptions{ + Field: collector.SortByLport, + Direction: collector.SortAsc, + }) + + if rt.Connections[0].Lport != 80 { + t.Errorf("expected first connection to have lport 80, got %d", rt.Connections[0].Lport) + } + if rt.Connections[1].Lport != 443 { + t.Errorf("expected second connection to have lport 443, got %d", rt.Connections[1].Lport) + } + if rt.Connections[2].Lport != 8080 { + t.Errorf("expected third connection to have lport 8080, got %d", rt.Connections[2].Lport) + } +} + +func TestRuntime_SortConnections_Desc(t *testing.T) { + rt := &Runtime{ + Connections: []collector.Connection{ + {Lport: 80}, + {Lport: 443}, + {Lport: 8080}, + }, + } + + rt.SortConnections(collector.SortOptions{ + Field: collector.SortByLport, + Direction: collector.SortDesc, + }) + + if rt.Connections[0].Lport != 8080 { + t.Errorf("expected first connection to have lport 8080, got %d", rt.Connections[0].Lport) + } +} + +func TestApplyFilter_AllKeys(t *testing.T) { + tests := []struct { + key string + value string + validate func(t *testing.T, f *collector.FilterOptions) + }{ + {"proto", "tcp", func(t *testing.T, f *collector.FilterOptions) { + if f.Proto != "tcp" { + t.Errorf("proto: expected 'tcp', got %q", f.Proto) + } + }}, + {"state", "LISTEN", func(t *testing.T, f *collector.FilterOptions) { + if f.State != "LISTEN" { + t.Errorf("state: expected 'LISTEN', got %q", f.State) + } + }}, + {"pid", "100", func(t *testing.T, f *collector.FilterOptions) { + if f.Pid != 100 { + t.Errorf("pid: expected 100, got %d", f.Pid) + } + }}, + {"proc", "nginx", func(t *testing.T, f *collector.FilterOptions) { + if f.Proc != "nginx" { + t.Errorf("proc: expected 'nginx', got %q", f.Proc) + } + }}, + {"lport", "80", func(t *testing.T, f *collector.FilterOptions) { + if f.Lport != 80 { + t.Errorf("lport: expected 80, got %d", f.Lport) + } + }}, + {"rport", "443", func(t *testing.T, f *collector.FilterOptions) { + if f.Rport != 443 { + t.Errorf("rport: expected 443, got %d", f.Rport) + } + }}, + {"laddr", "127.0.0.1", func(t *testing.T, f *collector.FilterOptions) { + if f.Laddr != "127.0.0.1" { + t.Errorf("laddr: expected '127.0.0.1', got %q", f.Laddr) + } + }}, + {"raddr", "8.8.8.8", func(t *testing.T, f *collector.FilterOptions) { + if f.Raddr != "8.8.8.8" { + t.Errorf("raddr: expected '8.8.8.8', got %q", f.Raddr) + } + }}, + {"contains", "test", func(t *testing.T, f *collector.FilterOptions) { + if f.Contains != "test" { + t.Errorf("contains: expected 'test', got %q", f.Contains) + } + }}, + {"if", "eth0", func(t *testing.T, f *collector.FilterOptions) { + if f.Interface != "eth0" { + t.Errorf("interface: expected 'eth0', got %q", f.Interface) + } + }}, + {"mark", "0xff", func(t *testing.T, f *collector.FilterOptions) { + if f.Mark != "0xff" { + t.Errorf("mark: expected '0xff', got %q", f.Mark) + } + }}, + {"namespace", "ns1", func(t *testing.T, f *collector.FilterOptions) { + if f.Namespace != "ns1" { + t.Errorf("namespace: expected 'ns1', got %q", f.Namespace) + } + }}, + {"inode", "12345", func(t *testing.T, f *collector.FilterOptions) { + if f.Inode != 12345 { + t.Errorf("inode: expected 12345, got %d", f.Inode) + } + }}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + filters := &collector.FilterOptions{} + err := applyFilter(filters, tt.key, tt.value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tt.validate(t, filters) + }) + } +} + diff --git a/cmd/top.go b/cmd/top.go index 33c617c..90b55a9 100644 --- a/cmd/top.go +++ b/cmd/top.go @@ -6,16 +6,15 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/karol-broda/snitch/internal/config" + "github.com/karol-broda/snitch/internal/resolver" "github.com/karol-broda/snitch/internal/tui" "github.com/spf13/cobra" ) // top-specific flags var ( - topTheme string - topInterval time.Duration - topResolveAddrs bool - topResolvePorts bool + topTheme string + topInterval time.Duration ) var topCmd = &cobra.Command{ @@ -29,11 +28,16 @@ var topCmd = &cobra.Command{ theme = cfg.Defaults.Theme } + // configure resolver with cache setting + effectiveNoCache := noCache || !cfg.Defaults.DNSCache + resolver.SetNoCache(effectiveNoCache) + opts := tui.Options{ Theme: theme, Interval: topInterval, - ResolveAddrs: topResolveAddrs, - ResolvePorts: topResolvePorts, + ResolveAddrs: resolveAddrs, + ResolvePorts: resolvePorts, + NoCache: effectiveNoCache, } // if any filter flag is set, use exclusive mode @@ -62,9 +66,8 @@ func init() { // top-specific flags topCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')") topCmd.Flags().DurationVarP(&topInterval, "interval", "i", time.Second, "Refresh interval") - topCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames") - topCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names") - // shared filter flags + // shared flags addFilterFlags(topCmd) + addResolutionFlags(topCmd) } \ No newline at end of file diff --git a/cmd/trace.go b/cmd/trace.go index 6a5d68d..b561ef8 100644 --- a/cmd/trace.go +++ b/cmd/trace.go @@ -7,12 +7,14 @@ import ( "log" "os" "os/signal" - "github.com/karol-broda/snitch/internal/collector" - "github.com/karol-broda/snitch/internal/resolver" "strings" "syscall" "time" + "github.com/karol-broda/snitch/internal/collector" + "github.com/karol-broda/snitch/internal/config" + "github.com/karol-broda/snitch/internal/resolver" + "github.com/spf13/cobra" ) @@ -23,11 +25,10 @@ type TraceEvent struct { } var ( - traceInterval time.Duration - traceCount int + traceInterval time.Duration + traceCount int traceOutputFormat string - traceNumeric bool - traceTimestamp bool + traceTimestamp bool ) var traceCmd = &cobra.Command{ @@ -47,6 +48,12 @@ Available filters: } func runTraceCommand(args []string) { + cfg := config.Get() + + // configure resolver with cache setting + effectiveNoCache := noCache || !cfg.Defaults.DNSCache + resolver.SetNoCache(effectiveNoCache) + filters, err := BuildFilters(args) if err != nil { log.Fatalf("Error parsing filters: %v", err) @@ -180,14 +187,16 @@ func printTraceEventHuman(event TraceEvent) { lportStr := fmt.Sprintf("%d", conn.Lport) rportStr := fmt.Sprintf("%d", conn.Rport) - // Handle name resolution based on numeric flag - if !traceNumeric { + // apply name resolution + if resolveAddrs { if resolvedLaddr := resolver.ResolveAddr(conn.Laddr); resolvedLaddr != conn.Laddr { laddr = resolvedLaddr } if resolvedRaddr := resolver.ResolveAddr(conn.Raddr); resolvedRaddr != conn.Raddr && conn.Raddr != "*" && conn.Raddr != "" { raddr = resolvedRaddr } + } + if resolvePorts { if resolvedLport := resolver.ResolvePort(conn.Lport, conn.Proto); resolvedLport != fmt.Sprintf("%d", conn.Lport) { lportStr = resolvedLport } @@ -225,9 +234,9 @@ func init() { traceCmd.Flags().DurationVarP(&traceInterval, "interval", "i", time.Second, "Polling interval (e.g., 500ms, 2s)") traceCmd.Flags().IntVarP(&traceCount, "count", "c", 0, "Number of events to capture (0 = unlimited)") traceCmd.Flags().StringVarP(&traceOutputFormat, "output", "o", "human", "Output format (human, json)") - traceCmd.Flags().BoolVarP(&traceNumeric, "numeric", "n", false, "Don't resolve hostnames") traceCmd.Flags().BoolVar(&traceTimestamp, "ts", false, "Include timestamp in output") - // shared filter flags + // shared flags addFilterFlags(traceCmd) + addResolutionFlags(traceCmd) } diff --git a/internal/config/config.go b/internal/config/config.go index 3e0f87d..64f5e87 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,6 +25,7 @@ type DefaultConfig struct { Units string `mapstructure:"units"` Color string `mapstructure:"color"` Resolve bool `mapstructure:"resolve"` + DNSCache bool `mapstructure:"dns_cache"` IPv4 bool `mapstructure:"ipv4"` IPv6 bool `mapstructure:"ipv6"` NoHeaders bool `mapstructure:"no_headers"` @@ -57,6 +58,7 @@ func Load() (*Config, error) { // environment variable bindings for readme-documented variables _ = v.BindEnv("config", "SNITCH_CONFIG") _ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE") + _ = v.BindEnv("defaults.dns_cache", "SNITCH_DNS_CACHE") _ = v.BindEnv("defaults.theme", "SNITCH_THEME") _ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR") @@ -90,7 +92,6 @@ func Load() (*Config, error) { } func setDefaults(v *viper.Viper) { - // Set default values matching the README specification v.SetDefault("defaults.interval", "1s") v.SetDefault("defaults.numeric", false) v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"}) @@ -98,6 +99,7 @@ func setDefaults(v *viper.Viper) { v.SetDefault("defaults.units", "auto") v.SetDefault("defaults.color", "auto") v.SetDefault("defaults.resolve", true) + v.SetDefault("defaults.dns_cache", true) v.SetDefault("defaults.ipv4", false) v.SetDefault("defaults.ipv6", false) v.SetDefault("defaults.no_headers", false) @@ -116,6 +118,11 @@ func handleSpecialEnvVars(v *viper.Viper) { v.Set("defaults.resolve", false) v.Set("defaults.numeric", true) } + + // Handle SNITCH_DNS_CACHE - if set to "0", disable dns caching + if os.Getenv("SNITCH_DNS_CACHE") == "0" { + v.Set("defaults.dns_cache", false) + } } // Get returns the global configuration, loading it if necessary @@ -123,7 +130,6 @@ func Get() *Config { if globalConfig == nil { config, err := Load() if err != nil { - // Return default config on error return &Config{ Defaults: DefaultConfig{ Interval: "1s", @@ -133,6 +139,7 @@ func Get() *Config { Units: "auto", Color: "auto", Resolve: true, + DNSCache: true, IPv4: false, IPv6: false, NoHeaders: false, diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index deed367..a8fdaa7 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -14,9 +14,10 @@ var debugTiming = os.Getenv("SNITCH_DEBUG_TIMING") != "" // Resolver handles DNS and service name resolution with caching and timeouts type Resolver struct { - timeout time.Duration - cache map[string]string - mutex sync.RWMutex + timeout time.Duration + cache map[string]string + mutex sync.RWMutex + noCache bool } // New creates a new resolver with the specified timeout @@ -24,37 +25,44 @@ func New(timeout time.Duration) *Resolver { return &Resolver{ timeout: timeout, cache: make(map[string]string), + noCache: false, } } +// SetNoCache disables caching - each lookup will hit DNS directly +func (r *Resolver) SetNoCache(noCache bool) { + r.noCache = noCache +} + // ResolveAddr resolves an IP address to a hostname, with caching func (r *Resolver) ResolveAddr(addr string) string { - // Check cache first - r.mutex.RLock() - if cached, exists := r.cache[addr]; exists { + // check cache first (unless caching is disabled) + if !r.noCache { + r.mutex.RLock() + if cached, exists := r.cache[addr]; exists { + r.mutex.RUnlock() + return cached + } r.mutex.RUnlock() - return cached } - r.mutex.RUnlock() - // Parse IP to validate it + // parse ip to validate it ip := net.ParseIP(addr) if ip == nil { - // Not a valid IP, return as-is return addr } - // Perform resolution with timeout + // perform resolution with timeout start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() names, err := net.DefaultResolver.LookupAddr(ctx, addr) - resolved := addr // fallback to original address + resolved := addr if err == nil && len(names) > 0 { resolved = names[0] - // Remove trailing dot if present + // remove trailing dot if present if len(resolved) > 0 && resolved[len(resolved)-1] == '.' { resolved = resolved[:len(resolved)-1] } @@ -65,10 +73,12 @@ func (r *Resolver) ResolveAddr(addr string) string { fmt.Fprintf(os.Stderr, "[timing] slow DNS lookup: %s -> %s (%v)\n", addr, resolved, elapsed) } - // Cache the result - r.mutex.Lock() - r.cache[addr] = resolved - r.mutex.Unlock() + // cache the result (unless caching is disabled) + if !r.noCache { + r.mutex.Lock() + r.cache[addr] = resolved + r.mutex.Unlock() + } return resolved } @@ -81,15 +91,17 @@ func (r *Resolver) ResolvePort(port int, proto string) string { cacheKey := strconv.Itoa(port) + "/" + proto - // Check cache first - r.mutex.RLock() - if cached, exists := r.cache[cacheKey]; exists { + // check cache first (unless caching is disabled) + if !r.noCache { + r.mutex.RLock() + if cached, exists := r.cache[cacheKey]; exists { + r.mutex.RUnlock() + return cached + } r.mutex.RUnlock() - return cached } - r.mutex.RUnlock() - // Perform resolution with timeout + // perform resolution with timeout ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() @@ -97,16 +109,18 @@ func (r *Resolver) ResolvePort(port int, proto string) string { resolved := strconv.Itoa(port) // fallback to port number if err == nil && service != 0 { - // Try to get service name + // try to get service name if serviceName := getServiceName(port, proto); serviceName != "" { resolved = serviceName } } - // Cache the result - r.mutex.Lock() - r.cache[cacheKey] = resolved - r.mutex.Unlock() + // cache the result (unless caching is disabled) + if !r.noCache { + r.mutex.Lock() + r.cache[cacheKey] = resolved + r.mutex.Unlock() + } return resolved } @@ -169,22 +183,38 @@ func getServiceName(port int, proto string) string { return "" } -// Global resolver instance +// global resolver instance var globalResolver *Resolver -// SetGlobalResolver sets the global resolver instance -func SetGlobalResolver(timeout time.Duration) { +// ResolverOptions configures the global resolver +type ResolverOptions struct { + Timeout time.Duration + NoCache bool +} + +// SetGlobalResolver sets the global resolver instance with options +func SetGlobalResolver(opts ResolverOptions) { + timeout := opts.Timeout + if timeout == 0 { + timeout = 200 * time.Millisecond + } globalResolver = New(timeout) + globalResolver.SetNoCache(opts.NoCache) } // GetGlobalResolver returns the global resolver instance func GetGlobalResolver() *Resolver { if globalResolver == nil { - globalResolver = New(200 * time.Millisecond) // Default timeout + globalResolver = New(200 * time.Millisecond) } return globalResolver } +// SetNoCache configures whether the global resolver bypasses cache +func SetNoCache(noCache bool) { + GetGlobalResolver().SetNoCache(noCache) +} + // ResolveAddr is a convenience function using the global resolver func ResolveAddr(addr string) string { return GetGlobalResolver().ResolveAddr(addr) @@ -199,3 +229,48 @@ func ResolvePort(port int, proto string) string { func ResolveAddrPort(addr string, port int, proto string) (string, string) { return GetGlobalResolver().ResolveAddrPort(addr, port, proto) } + +// ResolveAddrsParallel resolves multiple addresses concurrently and caches results. +// This should be called before rendering to pre-warm the cache. +func (r *Resolver) ResolveAddrsParallel(addrs []string) { + // dedupe and filter addresses that need resolution + unique := make(map[string]struct{}) + for _, addr := range addrs { + if addr == "" || addr == "*" { + continue + } + // skip if already cached + r.mutex.RLock() + _, exists := r.cache[addr] + r.mutex.RUnlock() + if exists { + continue + } + unique[addr] = struct{}{} + } + + if len(unique) == 0 { + return + } + + var wg sync.WaitGroup + // limit concurrency to avoid overwhelming dns + sem := make(chan struct{}, 32) + + for addr := range unique { + wg.Add(1) + go func(a string) { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + r.ResolveAddr(a) + }(addr) + } + + wg.Wait() +} + +// ResolveAddrsParallel is a convenience function using the global resolver +func ResolveAddrsParallel(addrs []string) { + GetGlobalResolver().ResolveAddrsParallel(addrs) +} diff --git a/internal/resolver/resolver_bench_test.go b/internal/resolver/resolver_bench_test.go new file mode 100644 index 0000000..85decc4 --- /dev/null +++ b/internal/resolver/resolver_bench_test.go @@ -0,0 +1,159 @@ +package resolver + +import ( + "fmt" + "testing" + "time" +) + +func BenchmarkResolveAddr_CacheHit(b *testing.B) { + r := New(100 * time.Millisecond) + addr := "127.0.0.1" + + // pre-populate cache + r.ResolveAddr(addr) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.ResolveAddr(addr) + } +} + +func BenchmarkResolveAddr_CacheMiss(b *testing.B) { + r := New(10 * time.Millisecond) // short timeout for faster benchmarks + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // use different addresses to avoid cache hits + addr := fmt.Sprintf("127.0.0.%d", i%256) + r.ClearCache() // clear cache to force miss + r.ResolveAddr(addr) + } +} + +func BenchmarkResolveAddr_NoCache(b *testing.B) { + r := New(10 * time.Millisecond) + r.SetNoCache(true) + addr := "127.0.0.1" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.ResolveAddr(addr) + } +} + +func BenchmarkResolvePort_CacheHit(b *testing.B) { + r := New(100 * time.Millisecond) + + // pre-populate cache + r.ResolvePort(80, "tcp") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.ResolvePort(80, "tcp") + } +} + +func BenchmarkResolvePort_WellKnown(b *testing.B) { + r := New(100 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.ClearCache() + r.ResolvePort(443, "tcp") + } +} + +func BenchmarkGetServiceName(b *testing.B) { + for i := 0; i < b.N; i++ { + getServiceName(80, "tcp") + } +} + +func BenchmarkGetServiceName_NotFound(b *testing.B) { + for i := 0; i < b.N; i++ { + getServiceName(12345, "tcp") + } +} + +func BenchmarkResolveAddrsParallel_10(b *testing.B) { + benchmarkResolveAddrsParallel(b, 10) +} + +func BenchmarkResolveAddrsParallel_100(b *testing.B) { + benchmarkResolveAddrsParallel(b, 100) +} + +func BenchmarkResolveAddrsParallel_1000(b *testing.B) { + benchmarkResolveAddrsParallel(b, 1000) +} + +func benchmarkResolveAddrsParallel(b *testing.B, count int) { + addrs := make([]string, count) + for i := 0; i < count; i++ { + addrs[i] = fmt.Sprintf("127.0.%d.%d", i/256, i%256) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r := New(10 * time.Millisecond) + r.ResolveAddrsParallel(addrs) + } +} + +func BenchmarkConcurrentResolveAddr(b *testing.B) { + r := New(100 * time.Millisecond) + addr := "127.0.0.1" + + // pre-populate cache + r.ResolveAddr(addr) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + r.ResolveAddr(addr) + } + }) +} + +func BenchmarkConcurrentResolvePort(b *testing.B) { + r := New(100 * time.Millisecond) + + // pre-populate cache + r.ResolvePort(80, "tcp") + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + r.ResolvePort(80, "tcp") + } + }) +} + +func BenchmarkGetCacheSize(b *testing.B) { + r := New(100 * time.Millisecond) + + // populate with some entries + for i := 0; i < 100; i++ { + r.ResolvePort(i+1, "tcp") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.GetCacheSize() + } +} + +func BenchmarkClearCache(b *testing.B) { + r := New(100 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // populate and clear + for j := 0; j < 10; j++ { + r.ResolvePort(j+1, "tcp") + } + r.ClearCache() + } +} + diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go new file mode 100644 index 0000000..4c9eb84 --- /dev/null +++ b/internal/resolver/resolver_test.go @@ -0,0 +1,387 @@ +package resolver + +import ( + "sync" + "testing" + "time" +) + +func TestNew(t *testing.T) { + r := New(100 * time.Millisecond) + if r == nil { + t.Fatal("expected non-nil resolver") + } + if r.timeout != 100*time.Millisecond { + t.Errorf("expected timeout 100ms, got %v", r.timeout) + } + if r.cache == nil { + t.Error("expected cache to be initialized") + } + if r.noCache { + t.Error("expected noCache to be false by default") + } +} + +func TestSetNoCache(t *testing.T) { + r := New(100 * time.Millisecond) + + r.SetNoCache(true) + if !r.noCache { + t.Error("expected noCache to be true") + } + + r.SetNoCache(false) + if r.noCache { + t.Error("expected noCache to be false") + } +} + +func TestResolveAddr_InvalidIP(t *testing.T) { + r := New(100 * time.Millisecond) + + // invalid ip should return as-is + result := r.ResolveAddr("not-an-ip") + if result != "not-an-ip" { + t.Errorf("expected 'not-an-ip', got %q", result) + } + + // empty string should return as-is + result = r.ResolveAddr("") + if result != "" { + t.Errorf("expected empty string, got %q", result) + } +} + +func TestResolveAddr_Caching(t *testing.T) { + r := New(100 * time.Millisecond) + + // first call should cache + addr := "127.0.0.1" + result1 := r.ResolveAddr(addr) + + // verify cache is populated + if r.GetCacheSize() != 1 { + t.Errorf("expected cache size 1, got %d", r.GetCacheSize()) + } + + // second call should use cache + result2 := r.ResolveAddr(addr) + if result1 != result2 { + t.Errorf("expected same result from cache, got %q and %q", result1, result2) + } +} + +func TestResolveAddr_NoCacheMode(t *testing.T) { + r := New(100 * time.Millisecond) + r.SetNoCache(true) + + addr := "127.0.0.1" + r.ResolveAddr(addr) + + // cache should remain empty when noCache is enabled + if r.GetCacheSize() != 0 { + t.Errorf("expected cache size 0 with noCache, got %d", r.GetCacheSize()) + } +} + +func TestResolvePort_Zero(t *testing.T) { + r := New(100 * time.Millisecond) + + result := r.ResolvePort(0, "tcp") + if result != "0" { + t.Errorf("expected '0' for port 0, got %q", result) + } +} + +func TestResolvePort_WellKnown(t *testing.T) { + r := New(100 * time.Millisecond) + + tests := []struct { + port int + proto string + expected string + }{ + {80, "tcp", "http"}, + {443, "tcp", "https"}, + {22, "tcp", "ssh"}, + {53, "udp", "domain"}, + {5432, "tcp", "postgresql"}, + } + + for _, tt := range tests { + result := r.ResolvePort(tt.port, tt.proto) + if result != tt.expected { + t.Errorf("ResolvePort(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected) + } + } +} + +func TestResolvePort_Caching(t *testing.T) { + r := New(100 * time.Millisecond) + + r.ResolvePort(80, "tcp") + r.ResolvePort(443, "tcp") + + if r.GetCacheSize() != 2 { + t.Errorf("expected cache size 2, got %d", r.GetCacheSize()) + } + + // same port/proto should not add new entry + r.ResolvePort(80, "tcp") + if r.GetCacheSize() != 2 { + t.Errorf("expected cache size still 2, got %d", r.GetCacheSize()) + } +} + +func TestResolveAddrPort(t *testing.T) { + r := New(100 * time.Millisecond) + + addr, port := r.ResolveAddrPort("127.0.0.1", 80, "tcp") + + if addr == "" { + t.Error("expected non-empty address") + } + if port != "http" { + t.Errorf("expected port 'http', got %q", port) + } +} + +func TestClearCache(t *testing.T) { + r := New(100 * time.Millisecond) + + r.ResolveAddr("127.0.0.1") + r.ResolvePort(80, "tcp") + + if r.GetCacheSize() == 0 { + t.Error("expected non-empty cache before clear") + } + + r.ClearCache() + + if r.GetCacheSize() != 0 { + t.Errorf("expected empty cache after clear, got %d", r.GetCacheSize()) + } +} + +func TestGetCacheSize(t *testing.T) { + r := New(100 * time.Millisecond) + + if r.GetCacheSize() != 0 { + t.Errorf("expected initial cache size 0, got %d", r.GetCacheSize()) + } + + r.ResolveAddr("127.0.0.1") + if r.GetCacheSize() != 1 { + t.Errorf("expected cache size 1, got %d", r.GetCacheSize()) + } +} + +func TestGetServiceName(t *testing.T) { + tests := []struct { + port int + proto string + expected string + }{ + {80, "tcp", "http"}, + {443, "tcp", "https"}, + {22, "tcp", "ssh"}, + {53, "tcp", "domain"}, + {53, "udp", "domain"}, + {12345, "tcp", ""}, + {0, "tcp", ""}, + } + + for _, tt := range tests { + result := getServiceName(tt.port, tt.proto) + if result != tt.expected { + t.Errorf("getServiceName(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected) + } + } +} + +func TestResolveAddrsParallel(t *testing.T) { + r := New(100 * time.Millisecond) + + addrs := []string{ + "127.0.0.1", + "127.0.0.2", + "127.0.0.3", + "", // should be skipped + "*", // should be skipped + } + + r.ResolveAddrsParallel(addrs) + + // should have cached 3 addresses (excluding empty and *) + if r.GetCacheSize() != 3 { + t.Errorf("expected cache size 3, got %d", r.GetCacheSize()) + } +} + +func TestResolveAddrsParallel_Dedupe(t *testing.T) { + r := New(100 * time.Millisecond) + + addrs := []string{ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.2", + } + + r.ResolveAddrsParallel(addrs) + + // should have cached 2 unique addresses + if r.GetCacheSize() != 2 { + t.Errorf("expected cache size 2, got %d", r.GetCacheSize()) + } +} + +func TestResolveAddrsParallel_SkipsCached(t *testing.T) { + r := New(100 * time.Millisecond) + + // pre-cache one address + r.ResolveAddr("127.0.0.1") + + addrs := []string{ + "127.0.0.1", // already cached + "127.0.0.2", // not cached + } + + initialSize := r.GetCacheSize() + r.ResolveAddrsParallel(addrs) + + // should have added 1 more + if r.GetCacheSize() != initialSize+1 { + t.Errorf("expected cache size %d, got %d", initialSize+1, r.GetCacheSize()) + } +} + +func TestResolveAddrsParallel_Empty(t *testing.T) { + r := New(100 * time.Millisecond) + + // should not panic with empty input + r.ResolveAddrsParallel([]string{}) + r.ResolveAddrsParallel(nil) + + if r.GetCacheSize() != 0 { + t.Errorf("expected cache size 0, got %d", r.GetCacheSize()) + } +} + +func TestGlobalResolver(t *testing.T) { + // reset global resolver + globalResolver = nil + + r := GetGlobalResolver() + if r == nil { + t.Fatal("expected non-nil global resolver") + } + + // should return same instance + r2 := GetGlobalResolver() + if r != r2 { + t.Error("expected same global resolver instance") + } +} + +func TestSetGlobalResolver(t *testing.T) { + SetGlobalResolver(ResolverOptions{ + Timeout: 500 * time.Millisecond, + NoCache: true, + }) + + r := GetGlobalResolver() + if r.timeout != 500*time.Millisecond { + t.Errorf("expected timeout 500ms, got %v", r.timeout) + } + if !r.noCache { + t.Error("expected noCache to be true") + } + + // reset for other tests + globalResolver = nil +} + +func TestSetGlobalResolver_DefaultTimeout(t *testing.T) { + SetGlobalResolver(ResolverOptions{ + Timeout: 0, // should use default + }) + + r := GetGlobalResolver() + if r.timeout != 200*time.Millisecond { + t.Errorf("expected default timeout 200ms, got %v", r.timeout) + } + + // reset for other tests + globalResolver = nil +} + +func TestGlobalConvenienceFunctions(t *testing.T) { + globalResolver = nil + + // test global ResolveAddr + result := ResolveAddr("127.0.0.1") + if result == "" { + t.Error("expected non-empty result from global ResolveAddr") + } + + // test global ResolvePort + port := ResolvePort(80, "tcp") + if port != "http" { + t.Errorf("expected 'http', got %q", port) + } + + // test global ResolveAddrPort + addr, portStr := ResolveAddrPort("127.0.0.1", 443, "tcp") + if addr == "" { + t.Error("expected non-empty address") + } + if portStr != "https" { + t.Errorf("expected 'https', got %q", portStr) + } + + // test global SetNoCache + SetNoCache(true) + if !GetGlobalResolver().noCache { + t.Error("expected global noCache to be true") + } + + // reset + globalResolver = nil +} + +func TestConcurrentAccess(t *testing.T) { + r := New(100 * time.Millisecond) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + addr := "127.0.0.1" + r.ResolveAddr(addr) + r.ResolvePort(80+n%10, "tcp") + r.GetCacheSize() + }(i) + } + + wg.Wait() + + // should not panic and cache should have entries + if r.GetCacheSize() == 0 { + t.Error("expected non-empty cache after concurrent access") + } +} + +func TestResolveAddr_TrailingDot(t *testing.T) { + // this test verifies the trailing dot removal logic + // by checking the internal logic works correctly + r := New(100 * time.Millisecond) + + // localhost should resolve and have trailing dot removed + result := r.ResolveAddr("127.0.0.1") + if len(result) > 0 && result[len(result)-1] == '.' { + t.Error("expected trailing dot to be removed") + } +} + diff --git a/internal/tui/messages.go b/internal/tui/messages.go index f9b8ca6..71e2996 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -3,6 +3,7 @@ package tui import ( "fmt" "github.com/karol-broda/snitch/internal/collector" + "github.com/karol-broda/snitch/internal/resolver" "syscall" "time" @@ -35,11 +36,20 @@ func (m model) tick() tea.Cmd { } func (m model) fetchData() tea.Cmd { + resolveAddrs := m.resolveAddrs return func() tea.Msg { conns, err := collector.GetConnections() if err != nil { return errMsg{err} } + // pre-warm dns cache in parallel if resolution is enabled + if resolveAddrs { + addrs := make([]string, 0, len(conns)*2) + for _, c := range conns { + addrs = append(addrs, c.Laddr, c.Raddr) + } + resolver.ResolveAddrsParallel(addrs) + } return dataMsg{connections: conns} } } diff --git a/internal/tui/model.go b/internal/tui/model.go index 9698df1..d3c9e3f 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -64,6 +64,7 @@ type Options struct { FilterSet bool // true if user specified any filter flags ResolveAddrs bool // when true, resolve IP addresses to hostnames ResolvePorts bool // when true, resolve port numbers to service names + NoCache bool // when true, disable DNS caching } func New(opts Options) model {