fix: dns resolution taking long and add caching options (#8)
This commit is contained in:
28
README.md
28
README.md
@@ -167,9 +167,20 @@ shortcut flags work on all commands:
|
|||||||
-e, --established established connections
|
-e, --established established connections
|
||||||
-4, --ipv4 ipv4 only
|
-4, --ipv4 ipv4 only
|
||||||
-6, --ipv6 ipv6 only
|
-6, --ipv6 ipv6 only
|
||||||
-n, --numeric no dns resolution
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## resolution
|
||||||
|
|
||||||
|
dns and service name resolution options:
|
||||||
|
|
||||||
|
```
|
||||||
|
--resolve-addrs resolve ip addresses to hostnames (default: true)
|
||||||
|
--resolve-ports resolve port numbers to service names
|
||||||
|
--no-cache disable dns caching (force fresh lookups)
|
||||||
|
```
|
||||||
|
|
||||||
|
dns lookups are performed in parallel and cached for performance. use `--no-cache` to bypass the cache for debugging or when addresses change frequently.
|
||||||
|
|
||||||
for more specific filtering, use `key=value` syntax with `ls`:
|
for more specific filtering, use `key=value` syntax with `ls`:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -208,8 +219,19 @@ optional config file at `~/.config/snitch/snitch.toml`:
|
|||||||
|
|
||||||
```toml
|
```toml
|
||||||
[defaults]
|
[defaults]
|
||||||
numeric = false
|
numeric = false # disable name resolution
|
||||||
theme = "auto"
|
dns_cache = true # cache dns lookups (set to false to disable)
|
||||||
|
theme = "auto" # color theme: auto, dark, light, mono
|
||||||
|
```
|
||||||
|
|
||||||
|
### environment variables
|
||||||
|
|
||||||
|
```bash
|
||||||
|
SNITCH_THEME=dark # set default theme
|
||||||
|
SNITCH_RESOLVE=0 # disable dns resolution
|
||||||
|
SNITCH_DNS_CACHE=0 # disable dns caching
|
||||||
|
SNITCH_NO_COLOR=1 # disable color output
|
||||||
|
SNITCH_CONFIG=/path/to # custom config file path
|
||||||
```
|
```
|
||||||
|
|
||||||
## requirements
|
## requirements
|
||||||
|
|||||||
@@ -30,8 +30,6 @@ var (
|
|||||||
sortBy string
|
sortBy string
|
||||||
fields string
|
fields string
|
||||||
colorMode string
|
colorMode string
|
||||||
resolveAddrs bool
|
|
||||||
resolvePorts bool
|
|
||||||
plainOutput bool
|
plainOutput bool
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -400,10 +398,9 @@ func init() {
|
|||||||
lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)")
|
lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)")
|
||||||
lsCmd.Flags().StringVarP(&fields, "fields", "f", strings.Join(cfg.Defaults.Fields, ","), "Comma-separated list of fields to show")
|
lsCmd.Flags().StringVarP(&fields, "fields", "f", strings.Join(cfg.Defaults.Fields, ","), "Comma-separated list of fields to show")
|
||||||
lsCmd.Flags().StringVar(&colorMode, "color", cfg.Defaults.Color, "Color mode (auto, always, never)")
|
lsCmd.Flags().StringVar(&colorMode, "color", cfg.Defaults.Color, "Color mode (auto, always, never)")
|
||||||
lsCmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
|
||||||
lsCmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
|
||||||
lsCmd.Flags().BoolVarP(&plainOutput, "plain", "p", false, "Plain output (parsable, no styling)")
|
lsCmd.Flags().BoolVarP(&plainOutput, "plain", "p", false, "Plain output (parsable, no styling)")
|
||||||
|
|
||||||
// shared filter flags
|
// shared flags
|
||||||
addFilterFlags(lsCmd)
|
addFilterFlags(lsCmd)
|
||||||
|
addResolutionFlags(lsCmd)
|
||||||
}
|
}
|
||||||
@@ -45,9 +45,8 @@ func init() {
|
|||||||
cfg := config.Get()
|
cfg := config.Get()
|
||||||
rootCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
|
rootCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
|
||||||
rootCmd.Flags().DurationVarP(&topInterval, "interval", "i", 0, "Refresh interval (default 1s)")
|
rootCmd.Flags().DurationVarP(&topInterval, "interval", "i", 0, "Refresh interval (default 1s)")
|
||||||
rootCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
|
||||||
rootCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
|
||||||
|
|
||||||
// shared filter flags for root command
|
// shared flags for root command
|
||||||
addFilterFlags(rootCmd)
|
addFilterFlags(rootCmd)
|
||||||
|
addResolutionFlags(rootCmd)
|
||||||
}
|
}
|
||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/karol-broda/snitch/internal/collector"
|
"github.com/karol-broda/snitch/internal/collector"
|
||||||
"github.com/karol-broda/snitch/internal/color"
|
"github.com/karol-broda/snitch/internal/color"
|
||||||
|
"github.com/karol-broda/snitch/internal/config"
|
||||||
|
"github.com/karol-broda/snitch/internal/resolver"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -11,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Runtime holds the shared state for all commands.
|
// Runtime holds the shared state for all commands.
|
||||||
// it handles common filter logic, fetching, and filtering connections.
|
// it handles common filter logic, fetching, filtering, and resolution.
|
||||||
type Runtime struct {
|
type Runtime struct {
|
||||||
// filter options built from flags and args
|
// filter options built from flags and args
|
||||||
Filters collector.FilterOptions
|
Filters collector.FilterOptions
|
||||||
@@ -23,6 +25,7 @@ type Runtime struct {
|
|||||||
ColorMode string
|
ColorMode string
|
||||||
ResolveAddrs bool
|
ResolveAddrs bool
|
||||||
ResolvePorts bool
|
ResolvePorts bool
|
||||||
|
NoCache bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// shared filter flags - used by all commands
|
// shared filter flags - used by all commands
|
||||||
@@ -35,6 +38,13 @@ var (
|
|||||||
filterIPv6 bool
|
filterIPv6 bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// shared resolution flags - used by all commands
|
||||||
|
var (
|
||||||
|
resolveAddrs bool
|
||||||
|
resolvePorts bool
|
||||||
|
noCache bool
|
||||||
|
)
|
||||||
|
|
||||||
// BuildFilters constructs FilterOptions from command args and shortcut flags.
|
// BuildFilters constructs FilterOptions from command args and shortcut flags.
|
||||||
func BuildFilters(args []string) (collector.FilterOptions, error) {
|
func BuildFilters(args []string) (collector.FilterOptions, error) {
|
||||||
filters, err := ParseFilterArgs(args)
|
filters, err := ParseFilterArgs(args)
|
||||||
@@ -77,6 +87,12 @@ func FetchConnections(filters collector.FilterOptions) ([]collector.Connection,
|
|||||||
func NewRuntime(args []string, colorMode string) (*Runtime, error) {
|
func NewRuntime(args []string, colorMode string) (*Runtime, error) {
|
||||||
color.Init(colorMode)
|
color.Init(colorMode)
|
||||||
|
|
||||||
|
cfg := config.Get()
|
||||||
|
|
||||||
|
// configure resolver with cache setting (flag overrides config)
|
||||||
|
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
|
||||||
|
resolver.SetNoCache(effectiveNoCache)
|
||||||
|
|
||||||
filters, err := BuildFilters(args)
|
filters, err := BuildFilters(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse filters: %w", err)
|
return nil, fmt.Errorf("failed to parse filters: %w", err)
|
||||||
@@ -87,13 +103,30 @@ func NewRuntime(args []string, colorMode string) (*Runtime, error) {
|
|||||||
return nil, fmt.Errorf("failed to fetch connections: %w", err)
|
return nil, fmt.Errorf("failed to fetch connections: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Runtime{
|
rt := &Runtime{
|
||||||
Filters: filters,
|
Filters: filters,
|
||||||
Connections: connections,
|
Connections: connections,
|
||||||
ColorMode: colorMode,
|
ColorMode: colorMode,
|
||||||
ResolveAddrs: resolveAddrs,
|
ResolveAddrs: resolveAddrs,
|
||||||
ResolvePorts: resolvePorts,
|
ResolvePorts: resolvePorts,
|
||||||
}, nil
|
NoCache: effectiveNoCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
// pre-warm dns cache by resolving all addresses in parallel
|
||||||
|
if resolveAddrs {
|
||||||
|
rt.PreWarmDNS()
|
||||||
|
}
|
||||||
|
|
||||||
|
return rt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreWarmDNS resolves all connection addresses in parallel to warm the cache.
|
||||||
|
func (r *Runtime) PreWarmDNS() {
|
||||||
|
addrs := make([]string, 0, len(r.Connections)*2)
|
||||||
|
for _, c := range r.Connections {
|
||||||
|
addrs = append(addrs, c.Laddr, c.Raddr)
|
||||||
|
}
|
||||||
|
resolver.ResolveAddrsParallel(addrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SortConnections sorts the runtime's connections in place.
|
// SortConnections sorts the runtime's connections in place.
|
||||||
@@ -201,3 +234,11 @@ func addFilterFlags(cmd *cobra.Command) {
|
|||||||
cmd.Flags().BoolVarP(&filterIPv6, "ipv6", "6", false, "Only show IPv6 connections")
|
cmd.Flags().BoolVarP(&filterIPv6, "ipv6", "6", false, "Only show IPv6 connections")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addResolutionFlags adds the common resolution flags to a command.
|
||||||
|
func addResolutionFlags(cmd *cobra.Command) {
|
||||||
|
cfg := config.Get()
|
||||||
|
cmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
||||||
|
cmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
||||||
|
cmd.Flags().BoolVar(&noCache, "no-cache", !cfg.Defaults.DNSCache, "Disable DNS caching (force fresh lookups)")
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
525
cmd/runtime_test.go
Normal file
525
cmd/runtime_test.go
Normal file
@@ -0,0 +1,525 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/karol-broda/snitch/internal/collector"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Empty(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proto != "" {
|
||||||
|
t.Errorf("expected empty proto, got %q", filters.Proto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Proto(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"proto=tcp"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proto != "tcp" {
|
||||||
|
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_State(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"state=LISTEN"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.State != "LISTEN" {
|
||||||
|
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_PID(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"pid=1234"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Pid != 1234 {
|
||||||
|
t.Errorf("expected pid 1234, got %d", filters.Pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_InvalidPID(t *testing.T) {
|
||||||
|
_, err := ParseFilterArgs([]string{"pid=notanumber"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid pid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Proc(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"proc=nginx"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proc != "nginx" {
|
||||||
|
t.Errorf("expected proc 'nginx', got %q", filters.Proc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Lport(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"lport=80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Lport != 80 {
|
||||||
|
t.Errorf("expected lport 80, got %d", filters.Lport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_InvalidLport(t *testing.T) {
|
||||||
|
_, err := ParseFilterArgs([]string{"lport=notaport"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid lport")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Rport(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"rport=443"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Rport != 443 {
|
||||||
|
t.Errorf("expected rport 443, got %d", filters.Rport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_InvalidRport(t *testing.T) {
|
||||||
|
_, err := ParseFilterArgs([]string{"rport=invalid"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid rport")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_UserByName(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"user=root"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.User != "root" {
|
||||||
|
t.Errorf("expected user 'root', got %q", filters.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_UserByUID(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"user=1000"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.UID != 1000 {
|
||||||
|
t.Errorf("expected uid 1000, got %d", filters.UID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Laddr(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"laddr=127.0.0.1"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Laddr != "127.0.0.1" {
|
||||||
|
t.Errorf("expected laddr '127.0.0.1', got %q", filters.Laddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Raddr(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"raddr=8.8.8.8"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Raddr != "8.8.8.8" {
|
||||||
|
t.Errorf("expected raddr '8.8.8.8', got %q", filters.Raddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Contains(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"contains=google"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Contains != "google" {
|
||||||
|
t.Errorf("expected contains 'google', got %q", filters.Contains)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Interface(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"if=eth0"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Interface != "eth0" {
|
||||||
|
t.Errorf("expected interface 'eth0', got %q", filters.Interface)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test alternative syntax
|
||||||
|
filters2, err := ParseFilterArgs([]string{"interface=lo"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters2.Interface != "lo" {
|
||||||
|
t.Errorf("expected interface 'lo', got %q", filters2.Interface)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Mark(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"mark=0x1234"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Mark != "0x1234" {
|
||||||
|
t.Errorf("expected mark '0x1234', got %q", filters.Mark)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Namespace(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"namespace=default"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Namespace != "default" {
|
||||||
|
t.Errorf("expected namespace 'default', got %q", filters.Namespace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Inode(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"inode=123456"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Inode != 123456 {
|
||||||
|
t.Errorf("expected inode 123456, got %d", filters.Inode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_InvalidInode(t *testing.T) {
|
||||||
|
_, err := ParseFilterArgs([]string{"inode=notanumber"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid inode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_Multiple(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"proto=tcp", "state=LISTEN", "lport=80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proto != "tcp" {
|
||||||
|
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||||
|
}
|
||||||
|
if filters.State != "LISTEN" {
|
||||||
|
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||||
|
}
|
||||||
|
if filters.Lport != 80 {
|
||||||
|
t.Errorf("expected lport 80, got %d", filters.Lport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_InvalidFormat(t *testing.T) {
|
||||||
|
_, err := ParseFilterArgs([]string{"invalidformat"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_UnknownKey(t *testing.T) {
|
||||||
|
_, err := ParseFilterArgs([]string{"unknownkey=value"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for unknown key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilterArgs_CaseInsensitiveKeys(t *testing.T) {
|
||||||
|
filters, err := ParseFilterArgs([]string{"PROTO=tcp", "State=LISTEN"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proto != "tcp" {
|
||||||
|
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||||
|
}
|
||||||
|
if filters.State != "LISTEN" {
|
||||||
|
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFilters_TCPOnly(t *testing.T) {
|
||||||
|
// save and restore global flags
|
||||||
|
oldTCP, oldUDP := filterTCP, filterUDP
|
||||||
|
defer func() {
|
||||||
|
filterTCP, filterUDP = oldTCP, oldUDP
|
||||||
|
}()
|
||||||
|
|
||||||
|
filterTCP = true
|
||||||
|
filterUDP = false
|
||||||
|
|
||||||
|
filters, err := BuildFilters([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proto != "tcp" {
|
||||||
|
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFilters_UDPOnly(t *testing.T) {
|
||||||
|
oldTCP, oldUDP := filterTCP, filterUDP
|
||||||
|
defer func() {
|
||||||
|
filterTCP, filterUDP = oldTCP, oldUDP
|
||||||
|
}()
|
||||||
|
|
||||||
|
filterTCP = false
|
||||||
|
filterUDP = true
|
||||||
|
|
||||||
|
filters, err := BuildFilters([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proto != "udp" {
|
||||||
|
t.Errorf("expected proto 'udp', got %q", filters.Proto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFilters_ListenOnly(t *testing.T) {
|
||||||
|
oldListen, oldEstab := filterListen, filterEstab
|
||||||
|
defer func() {
|
||||||
|
filterListen, filterEstab = oldListen, oldEstab
|
||||||
|
}()
|
||||||
|
|
||||||
|
filterListen = true
|
||||||
|
filterEstab = false
|
||||||
|
|
||||||
|
filters, err := BuildFilters([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.State != "LISTEN" {
|
||||||
|
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFilters_EstablishedOnly(t *testing.T) {
|
||||||
|
oldListen, oldEstab := filterListen, filterEstab
|
||||||
|
defer func() {
|
||||||
|
filterListen, filterEstab = oldListen, oldEstab
|
||||||
|
}()
|
||||||
|
|
||||||
|
filterListen = false
|
||||||
|
filterEstab = true
|
||||||
|
|
||||||
|
filters, err := BuildFilters([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.State != "ESTABLISHED" {
|
||||||
|
t.Errorf("expected state 'ESTABLISHED', got %q", filters.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFilters_IPv4Flag(t *testing.T) {
|
||||||
|
oldIPv4 := filterIPv4
|
||||||
|
defer func() {
|
||||||
|
filterIPv4 = oldIPv4
|
||||||
|
}()
|
||||||
|
|
||||||
|
filterIPv4 = true
|
||||||
|
|
||||||
|
filters, err := BuildFilters([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !filters.IPv4 {
|
||||||
|
t.Error("expected IPv4 to be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFilters_IPv6Flag(t *testing.T) {
|
||||||
|
oldIPv6 := filterIPv6
|
||||||
|
defer func() {
|
||||||
|
filterIPv6 = oldIPv6
|
||||||
|
}()
|
||||||
|
|
||||||
|
filterIPv6 = true
|
||||||
|
|
||||||
|
filters, err := BuildFilters([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !filters.IPv6 {
|
||||||
|
t.Error("expected IPv6 to be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFilters_CombinedArgsAndFlags(t *testing.T) {
|
||||||
|
oldTCP := filterTCP
|
||||||
|
defer func() {
|
||||||
|
filterTCP = oldTCP
|
||||||
|
}()
|
||||||
|
|
||||||
|
filterTCP = true
|
||||||
|
|
||||||
|
filters, err := BuildFilters([]string{"lport=80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if filters.Proto != "tcp" {
|
||||||
|
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||||
|
}
|
||||||
|
if filters.Lport != 80 {
|
||||||
|
t.Errorf("expected lport 80, got %d", filters.Lport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuntime_PreWarmDNS(t *testing.T) {
|
||||||
|
rt := &Runtime{
|
||||||
|
Connections: []collector.Connection{
|
||||||
|
{Laddr: "127.0.0.1", Raddr: "192.168.1.1"},
|
||||||
|
{Laddr: "127.0.0.1", Raddr: "10.0.0.1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// should not panic
|
||||||
|
rt.PreWarmDNS()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuntime_PreWarmDNS_Empty(t *testing.T) {
|
||||||
|
rt := &Runtime{
|
||||||
|
Connections: []collector.Connection{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// should not panic with empty connections
|
||||||
|
rt.PreWarmDNS()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuntime_SortConnections(t *testing.T) {
|
||||||
|
rt := &Runtime{
|
||||||
|
Connections: []collector.Connection{
|
||||||
|
{Lport: 443},
|
||||||
|
{Lport: 80},
|
||||||
|
{Lport: 8080},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rt.SortConnections(collector.SortOptions{
|
||||||
|
Field: collector.SortByLport,
|
||||||
|
Direction: collector.SortAsc,
|
||||||
|
})
|
||||||
|
|
||||||
|
if rt.Connections[0].Lport != 80 {
|
||||||
|
t.Errorf("expected first connection to have lport 80, got %d", rt.Connections[0].Lport)
|
||||||
|
}
|
||||||
|
if rt.Connections[1].Lport != 443 {
|
||||||
|
t.Errorf("expected second connection to have lport 443, got %d", rt.Connections[1].Lport)
|
||||||
|
}
|
||||||
|
if rt.Connections[2].Lport != 8080 {
|
||||||
|
t.Errorf("expected third connection to have lport 8080, got %d", rt.Connections[2].Lport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuntime_SortConnections_Desc(t *testing.T) {
|
||||||
|
rt := &Runtime{
|
||||||
|
Connections: []collector.Connection{
|
||||||
|
{Lport: 80},
|
||||||
|
{Lport: 443},
|
||||||
|
{Lport: 8080},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rt.SortConnections(collector.SortOptions{
|
||||||
|
Field: collector.SortByLport,
|
||||||
|
Direction: collector.SortDesc,
|
||||||
|
})
|
||||||
|
|
||||||
|
if rt.Connections[0].Lport != 8080 {
|
||||||
|
t.Errorf("expected first connection to have lport 8080, got %d", rt.Connections[0].Lport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyFilter_AllKeys(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
validate func(t *testing.T, f *collector.FilterOptions)
|
||||||
|
}{
|
||||||
|
{"proto", "tcp", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Proto != "tcp" {
|
||||||
|
t.Errorf("proto: expected 'tcp', got %q", f.Proto)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"state", "LISTEN", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.State != "LISTEN" {
|
||||||
|
t.Errorf("state: expected 'LISTEN', got %q", f.State)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"pid", "100", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Pid != 100 {
|
||||||
|
t.Errorf("pid: expected 100, got %d", f.Pid)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"proc", "nginx", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Proc != "nginx" {
|
||||||
|
t.Errorf("proc: expected 'nginx', got %q", f.Proc)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"lport", "80", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Lport != 80 {
|
||||||
|
t.Errorf("lport: expected 80, got %d", f.Lport)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"rport", "443", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Rport != 443 {
|
||||||
|
t.Errorf("rport: expected 443, got %d", f.Rport)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"laddr", "127.0.0.1", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Laddr != "127.0.0.1" {
|
||||||
|
t.Errorf("laddr: expected '127.0.0.1', got %q", f.Laddr)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"raddr", "8.8.8.8", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Raddr != "8.8.8.8" {
|
||||||
|
t.Errorf("raddr: expected '8.8.8.8', got %q", f.Raddr)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"contains", "test", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Contains != "test" {
|
||||||
|
t.Errorf("contains: expected 'test', got %q", f.Contains)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"if", "eth0", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Interface != "eth0" {
|
||||||
|
t.Errorf("interface: expected 'eth0', got %q", f.Interface)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"mark", "0xff", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Mark != "0xff" {
|
||||||
|
t.Errorf("mark: expected '0xff', got %q", f.Mark)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"namespace", "ns1", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Namespace != "ns1" {
|
||||||
|
t.Errorf("namespace: expected 'ns1', got %q", f.Namespace)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"inode", "12345", func(t *testing.T, f *collector.FilterOptions) {
|
||||||
|
if f.Inode != 12345 {
|
||||||
|
t.Errorf("inode: expected 12345, got %d", f.Inode)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.key, func(t *testing.T) {
|
||||||
|
filters := &collector.FilterOptions{}
|
||||||
|
err := applyFilter(filters, tt.key, tt.value)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
tt.validate(t, filters)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
17
cmd/top.go
17
cmd/top.go
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/karol-broda/snitch/internal/config"
|
"github.com/karol-broda/snitch/internal/config"
|
||||||
|
"github.com/karol-broda/snitch/internal/resolver"
|
||||||
"github.com/karol-broda/snitch/internal/tui"
|
"github.com/karol-broda/snitch/internal/tui"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
@@ -14,8 +15,6 @@ import (
|
|||||||
var (
|
var (
|
||||||
topTheme string
|
topTheme string
|
||||||
topInterval time.Duration
|
topInterval time.Duration
|
||||||
topResolveAddrs bool
|
|
||||||
topResolvePorts bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var topCmd = &cobra.Command{
|
var topCmd = &cobra.Command{
|
||||||
@@ -29,11 +28,16 @@ var topCmd = &cobra.Command{
|
|||||||
theme = cfg.Defaults.Theme
|
theme = cfg.Defaults.Theme
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// configure resolver with cache setting
|
||||||
|
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
|
||||||
|
resolver.SetNoCache(effectiveNoCache)
|
||||||
|
|
||||||
opts := tui.Options{
|
opts := tui.Options{
|
||||||
Theme: theme,
|
Theme: theme,
|
||||||
Interval: topInterval,
|
Interval: topInterval,
|
||||||
ResolveAddrs: topResolveAddrs,
|
ResolveAddrs: resolveAddrs,
|
||||||
ResolvePorts: topResolvePorts,
|
ResolvePorts: resolvePorts,
|
||||||
|
NoCache: effectiveNoCache,
|
||||||
}
|
}
|
||||||
|
|
||||||
// if any filter flag is set, use exclusive mode
|
// if any filter flag is set, use exclusive mode
|
||||||
@@ -62,9 +66,8 @@ func init() {
|
|||||||
// top-specific flags
|
// top-specific flags
|
||||||
topCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
|
topCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
|
||||||
topCmd.Flags().DurationVarP(&topInterval, "interval", "i", time.Second, "Refresh interval")
|
topCmd.Flags().DurationVarP(&topInterval, "interval", "i", time.Second, "Refresh interval")
|
||||||
topCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
|
||||||
topCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
|
||||||
|
|
||||||
// shared filter flags
|
// shared flags
|
||||||
addFilterFlags(topCmd)
|
addFilterFlags(topCmd)
|
||||||
|
addResolutionFlags(topCmd)
|
||||||
}
|
}
|
||||||
23
cmd/trace.go
23
cmd/trace.go
@@ -7,12 +7,14 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"github.com/karol-broda/snitch/internal/collector"
|
|
||||||
"github.com/karol-broda/snitch/internal/resolver"
|
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/karol-broda/snitch/internal/collector"
|
||||||
|
"github.com/karol-broda/snitch/internal/config"
|
||||||
|
"github.com/karol-broda/snitch/internal/resolver"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +28,6 @@ var (
|
|||||||
traceInterval time.Duration
|
traceInterval time.Duration
|
||||||
traceCount int
|
traceCount int
|
||||||
traceOutputFormat string
|
traceOutputFormat string
|
||||||
traceNumeric bool
|
|
||||||
traceTimestamp bool
|
traceTimestamp bool
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,6 +48,12 @@ Available filters:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runTraceCommand(args []string) {
|
func runTraceCommand(args []string) {
|
||||||
|
cfg := config.Get()
|
||||||
|
|
||||||
|
// configure resolver with cache setting
|
||||||
|
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
|
||||||
|
resolver.SetNoCache(effectiveNoCache)
|
||||||
|
|
||||||
filters, err := BuildFilters(args)
|
filters, err := BuildFilters(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error parsing filters: %v", err)
|
log.Fatalf("Error parsing filters: %v", err)
|
||||||
@@ -180,14 +187,16 @@ func printTraceEventHuman(event TraceEvent) {
|
|||||||
lportStr := fmt.Sprintf("%d", conn.Lport)
|
lportStr := fmt.Sprintf("%d", conn.Lport)
|
||||||
rportStr := fmt.Sprintf("%d", conn.Rport)
|
rportStr := fmt.Sprintf("%d", conn.Rport)
|
||||||
|
|
||||||
// Handle name resolution based on numeric flag
|
// apply name resolution
|
||||||
if !traceNumeric {
|
if resolveAddrs {
|
||||||
if resolvedLaddr := resolver.ResolveAddr(conn.Laddr); resolvedLaddr != conn.Laddr {
|
if resolvedLaddr := resolver.ResolveAddr(conn.Laddr); resolvedLaddr != conn.Laddr {
|
||||||
laddr = resolvedLaddr
|
laddr = resolvedLaddr
|
||||||
}
|
}
|
||||||
if resolvedRaddr := resolver.ResolveAddr(conn.Raddr); resolvedRaddr != conn.Raddr && conn.Raddr != "*" && conn.Raddr != "" {
|
if resolvedRaddr := resolver.ResolveAddr(conn.Raddr); resolvedRaddr != conn.Raddr && conn.Raddr != "*" && conn.Raddr != "" {
|
||||||
raddr = resolvedRaddr
|
raddr = resolvedRaddr
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if resolvePorts {
|
||||||
if resolvedLport := resolver.ResolvePort(conn.Lport, conn.Proto); resolvedLport != fmt.Sprintf("%d", conn.Lport) {
|
if resolvedLport := resolver.ResolvePort(conn.Lport, conn.Proto); resolvedLport != fmt.Sprintf("%d", conn.Lport) {
|
||||||
lportStr = resolvedLport
|
lportStr = resolvedLport
|
||||||
}
|
}
|
||||||
@@ -225,9 +234,9 @@ func init() {
|
|||||||
traceCmd.Flags().DurationVarP(&traceInterval, "interval", "i", time.Second, "Polling interval (e.g., 500ms, 2s)")
|
traceCmd.Flags().DurationVarP(&traceInterval, "interval", "i", time.Second, "Polling interval (e.g., 500ms, 2s)")
|
||||||
traceCmd.Flags().IntVarP(&traceCount, "count", "c", 0, "Number of events to capture (0 = unlimited)")
|
traceCmd.Flags().IntVarP(&traceCount, "count", "c", 0, "Number of events to capture (0 = unlimited)")
|
||||||
traceCmd.Flags().StringVarP(&traceOutputFormat, "output", "o", "human", "Output format (human, json)")
|
traceCmd.Flags().StringVarP(&traceOutputFormat, "output", "o", "human", "Output format (human, json)")
|
||||||
traceCmd.Flags().BoolVarP(&traceNumeric, "numeric", "n", false, "Don't resolve hostnames")
|
|
||||||
traceCmd.Flags().BoolVar(&traceTimestamp, "ts", false, "Include timestamp in output")
|
traceCmd.Flags().BoolVar(&traceTimestamp, "ts", false, "Include timestamp in output")
|
||||||
|
|
||||||
// shared filter flags
|
// shared flags
|
||||||
addFilterFlags(traceCmd)
|
addFilterFlags(traceCmd)
|
||||||
|
addResolutionFlags(traceCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type DefaultConfig struct {
|
|||||||
Units string `mapstructure:"units"`
|
Units string `mapstructure:"units"`
|
||||||
Color string `mapstructure:"color"`
|
Color string `mapstructure:"color"`
|
||||||
Resolve bool `mapstructure:"resolve"`
|
Resolve bool `mapstructure:"resolve"`
|
||||||
|
DNSCache bool `mapstructure:"dns_cache"`
|
||||||
IPv4 bool `mapstructure:"ipv4"`
|
IPv4 bool `mapstructure:"ipv4"`
|
||||||
IPv6 bool `mapstructure:"ipv6"`
|
IPv6 bool `mapstructure:"ipv6"`
|
||||||
NoHeaders bool `mapstructure:"no_headers"`
|
NoHeaders bool `mapstructure:"no_headers"`
|
||||||
@@ -57,6 +58,7 @@ func Load() (*Config, error) {
|
|||||||
// environment variable bindings for readme-documented variables
|
// environment variable bindings for readme-documented variables
|
||||||
_ = v.BindEnv("config", "SNITCH_CONFIG")
|
_ = v.BindEnv("config", "SNITCH_CONFIG")
|
||||||
_ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE")
|
_ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE")
|
||||||
|
_ = v.BindEnv("defaults.dns_cache", "SNITCH_DNS_CACHE")
|
||||||
_ = v.BindEnv("defaults.theme", "SNITCH_THEME")
|
_ = v.BindEnv("defaults.theme", "SNITCH_THEME")
|
||||||
_ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR")
|
_ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR")
|
||||||
|
|
||||||
@@ -90,7 +92,6 @@ func Load() (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setDefaults(v *viper.Viper) {
|
func setDefaults(v *viper.Viper) {
|
||||||
// Set default values matching the README specification
|
|
||||||
v.SetDefault("defaults.interval", "1s")
|
v.SetDefault("defaults.interval", "1s")
|
||||||
v.SetDefault("defaults.numeric", false)
|
v.SetDefault("defaults.numeric", false)
|
||||||
v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"})
|
v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"})
|
||||||
@@ -98,6 +99,7 @@ func setDefaults(v *viper.Viper) {
|
|||||||
v.SetDefault("defaults.units", "auto")
|
v.SetDefault("defaults.units", "auto")
|
||||||
v.SetDefault("defaults.color", "auto")
|
v.SetDefault("defaults.color", "auto")
|
||||||
v.SetDefault("defaults.resolve", true)
|
v.SetDefault("defaults.resolve", true)
|
||||||
|
v.SetDefault("defaults.dns_cache", true)
|
||||||
v.SetDefault("defaults.ipv4", false)
|
v.SetDefault("defaults.ipv4", false)
|
||||||
v.SetDefault("defaults.ipv6", false)
|
v.SetDefault("defaults.ipv6", false)
|
||||||
v.SetDefault("defaults.no_headers", false)
|
v.SetDefault("defaults.no_headers", false)
|
||||||
@@ -116,6 +118,11 @@ func handleSpecialEnvVars(v *viper.Viper) {
|
|||||||
v.Set("defaults.resolve", false)
|
v.Set("defaults.resolve", false)
|
||||||
v.Set("defaults.numeric", true)
|
v.Set("defaults.numeric", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle SNITCH_DNS_CACHE - if set to "0", disable dns caching
|
||||||
|
if os.Getenv("SNITCH_DNS_CACHE") == "0" {
|
||||||
|
v.Set("defaults.dns_cache", false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the global configuration, loading it if necessary
|
// Get returns the global configuration, loading it if necessary
|
||||||
@@ -123,7 +130,6 @@ func Get() *Config {
|
|||||||
if globalConfig == nil {
|
if globalConfig == nil {
|
||||||
config, err := Load()
|
config, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Return default config on error
|
|
||||||
return &Config{
|
return &Config{
|
||||||
Defaults: DefaultConfig{
|
Defaults: DefaultConfig{
|
||||||
Interval: "1s",
|
Interval: "1s",
|
||||||
@@ -133,6 +139,7 @@ func Get() *Config {
|
|||||||
Units: "auto",
|
Units: "auto",
|
||||||
Color: "auto",
|
Color: "auto",
|
||||||
Resolve: true,
|
Resolve: true,
|
||||||
|
DNSCache: true,
|
||||||
IPv4: false,
|
IPv4: false,
|
||||||
IPv6: false,
|
IPv6: false,
|
||||||
NoHeaders: false,
|
NoHeaders: false,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type Resolver struct {
|
|||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cache map[string]string
|
cache map[string]string
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
noCache bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new resolver with the specified timeout
|
// New creates a new resolver with the specified timeout
|
||||||
@@ -24,37 +25,44 @@ func New(timeout time.Duration) *Resolver {
|
|||||||
return &Resolver{
|
return &Resolver{
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cache: make(map[string]string),
|
cache: make(map[string]string),
|
||||||
|
noCache: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNoCache disables caching - each lookup will hit DNS directly
|
||||||
|
func (r *Resolver) SetNoCache(noCache bool) {
|
||||||
|
r.noCache = noCache
|
||||||
|
}
|
||||||
|
|
||||||
// ResolveAddr resolves an IP address to a hostname, with caching
|
// ResolveAddr resolves an IP address to a hostname, with caching
|
||||||
func (r *Resolver) ResolveAddr(addr string) string {
|
func (r *Resolver) ResolveAddr(addr string) string {
|
||||||
// Check cache first
|
// check cache first (unless caching is disabled)
|
||||||
|
if !r.noCache {
|
||||||
r.mutex.RLock()
|
r.mutex.RLock()
|
||||||
if cached, exists := r.cache[addr]; exists {
|
if cached, exists := r.cache[addr]; exists {
|
||||||
r.mutex.RUnlock()
|
r.mutex.RUnlock()
|
||||||
return cached
|
return cached
|
||||||
}
|
}
|
||||||
r.mutex.RUnlock()
|
r.mutex.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Parse IP to validate it
|
// parse ip to validate it
|
||||||
ip := net.ParseIP(addr)
|
ip := net.ParseIP(addr)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
// Not a valid IP, return as-is
|
|
||||||
return addr
|
return addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform resolution with timeout
|
// perform resolution with timeout
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
names, err := net.DefaultResolver.LookupAddr(ctx, addr)
|
names, err := net.DefaultResolver.LookupAddr(ctx, addr)
|
||||||
|
|
||||||
resolved := addr // fallback to original address
|
resolved := addr
|
||||||
if err == nil && len(names) > 0 {
|
if err == nil && len(names) > 0 {
|
||||||
resolved = names[0]
|
resolved = names[0]
|
||||||
// Remove trailing dot if present
|
// remove trailing dot if present
|
||||||
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
|
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
|
||||||
resolved = resolved[:len(resolved)-1]
|
resolved = resolved[:len(resolved)-1]
|
||||||
}
|
}
|
||||||
@@ -65,10 +73,12 @@ func (r *Resolver) ResolveAddr(addr string) string {
|
|||||||
fmt.Fprintf(os.Stderr, "[timing] slow DNS lookup: %s -> %s (%v)\n", addr, resolved, elapsed)
|
fmt.Fprintf(os.Stderr, "[timing] slow DNS lookup: %s -> %s (%v)\n", addr, resolved, elapsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the result
|
// cache the result (unless caching is disabled)
|
||||||
|
if !r.noCache {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
r.cache[addr] = resolved
|
r.cache[addr] = resolved
|
||||||
r.mutex.Unlock()
|
r.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
return resolved
|
return resolved
|
||||||
}
|
}
|
||||||
@@ -81,15 +91,17 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
|
|||||||
|
|
||||||
cacheKey := strconv.Itoa(port) + "/" + proto
|
cacheKey := strconv.Itoa(port) + "/" + proto
|
||||||
|
|
||||||
// Check cache first
|
// check cache first (unless caching is disabled)
|
||||||
|
if !r.noCache {
|
||||||
r.mutex.RLock()
|
r.mutex.RLock()
|
||||||
if cached, exists := r.cache[cacheKey]; exists {
|
if cached, exists := r.cache[cacheKey]; exists {
|
||||||
r.mutex.RUnlock()
|
r.mutex.RUnlock()
|
||||||
return cached
|
return cached
|
||||||
}
|
}
|
||||||
r.mutex.RUnlock()
|
r.mutex.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Perform resolution with timeout
|
// perform resolution with timeout
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -97,16 +109,18 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
|
|||||||
|
|
||||||
resolved := strconv.Itoa(port) // fallback to port number
|
resolved := strconv.Itoa(port) // fallback to port number
|
||||||
if err == nil && service != 0 {
|
if err == nil && service != 0 {
|
||||||
// Try to get service name
|
// try to get service name
|
||||||
if serviceName := getServiceName(port, proto); serviceName != "" {
|
if serviceName := getServiceName(port, proto); serviceName != "" {
|
||||||
resolved = serviceName
|
resolved = serviceName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the result
|
// cache the result (unless caching is disabled)
|
||||||
|
if !r.noCache {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
r.cache[cacheKey] = resolved
|
r.cache[cacheKey] = resolved
|
||||||
r.mutex.Unlock()
|
r.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
return resolved
|
return resolved
|
||||||
}
|
}
|
||||||
@@ -169,22 +183,38 @@ func getServiceName(port int, proto string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global resolver instance
|
// global resolver instance
|
||||||
var globalResolver *Resolver
|
var globalResolver *Resolver
|
||||||
|
|
||||||
// SetGlobalResolver sets the global resolver instance
|
// ResolverOptions configures the global resolver
|
||||||
func SetGlobalResolver(timeout time.Duration) {
|
type ResolverOptions struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
NoCache bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGlobalResolver sets the global resolver instance with options
|
||||||
|
func SetGlobalResolver(opts ResolverOptions) {
|
||||||
|
timeout := opts.Timeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 200 * time.Millisecond
|
||||||
|
}
|
||||||
globalResolver = New(timeout)
|
globalResolver = New(timeout)
|
||||||
|
globalResolver.SetNoCache(opts.NoCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGlobalResolver returns the global resolver instance
|
// GetGlobalResolver returns the global resolver instance
|
||||||
func GetGlobalResolver() *Resolver {
|
func GetGlobalResolver() *Resolver {
|
||||||
if globalResolver == nil {
|
if globalResolver == nil {
|
||||||
globalResolver = New(200 * time.Millisecond) // Default timeout
|
globalResolver = New(200 * time.Millisecond)
|
||||||
}
|
}
|
||||||
return globalResolver
|
return globalResolver
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNoCache configures whether the global resolver bypasses cache
|
||||||
|
func SetNoCache(noCache bool) {
|
||||||
|
GetGlobalResolver().SetNoCache(noCache)
|
||||||
|
}
|
||||||
|
|
||||||
// ResolveAddr is a convenience function using the global resolver
|
// ResolveAddr is a convenience function using the global resolver
|
||||||
func ResolveAddr(addr string) string {
|
func ResolveAddr(addr string) string {
|
||||||
return GetGlobalResolver().ResolveAddr(addr)
|
return GetGlobalResolver().ResolveAddr(addr)
|
||||||
@@ -199,3 +229,48 @@ func ResolvePort(port int, proto string) string {
|
|||||||
func ResolveAddrPort(addr string, port int, proto string) (string, string) {
|
func ResolveAddrPort(addr string, port int, proto string) (string, string) {
|
||||||
return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
|
return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveAddrsParallel resolves multiple addresses concurrently and caches results.
|
||||||
|
// This should be called before rendering to pre-warm the cache.
|
||||||
|
func (r *Resolver) ResolveAddrsParallel(addrs []string) {
|
||||||
|
// dedupe and filter addresses that need resolution
|
||||||
|
unique := make(map[string]struct{})
|
||||||
|
for _, addr := range addrs {
|
||||||
|
if addr == "" || addr == "*" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// skip if already cached
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, exists := r.cache[addr]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
if exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
unique[addr] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unique) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
// limit concurrency to avoid overwhelming dns
|
||||||
|
sem := make(chan struct{}, 32)
|
||||||
|
|
||||||
|
for addr := range unique {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(a string) {
|
||||||
|
defer wg.Done()
|
||||||
|
sem <- struct{}{}
|
||||||
|
defer func() { <-sem }()
|
||||||
|
r.ResolveAddr(a)
|
||||||
|
}(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveAddrsParallel is a convenience function using the global resolver
|
||||||
|
func ResolveAddrsParallel(addrs []string) {
|
||||||
|
GetGlobalResolver().ResolveAddrsParallel(addrs)
|
||||||
|
}
|
||||||
|
|||||||
159
internal/resolver/resolver_bench_test.go
Normal file
159
internal/resolver/resolver_bench_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkResolveAddr_CacheHit(b *testing.B) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
addr := "127.0.0.1"
|
||||||
|
|
||||||
|
// pre-populate cache
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResolveAddr_CacheMiss(b *testing.B) {
|
||||||
|
r := New(10 * time.Millisecond) // short timeout for faster benchmarks
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// use different addresses to avoid cache hits
|
||||||
|
addr := fmt.Sprintf("127.0.0.%d", i%256)
|
||||||
|
r.ClearCache() // clear cache to force miss
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResolveAddr_NoCache(b *testing.B) {
|
||||||
|
r := New(10 * time.Millisecond)
|
||||||
|
r.SetNoCache(true)
|
||||||
|
addr := "127.0.0.1"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResolvePort_CacheHit(b *testing.B) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// pre-populate cache
|
||||||
|
r.ResolvePort(80, "tcp")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.ResolvePort(80, "tcp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResolvePort_WellKnown(b *testing.B) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.ClearCache()
|
||||||
|
r.ResolvePort(443, "tcp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetServiceName(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
getServiceName(80, "tcp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetServiceName_NotFound(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
getServiceName(12345, "tcp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResolveAddrsParallel_10(b *testing.B) {
|
||||||
|
benchmarkResolveAddrsParallel(b, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResolveAddrsParallel_100(b *testing.B) {
|
||||||
|
benchmarkResolveAddrsParallel(b, 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResolveAddrsParallel_1000(b *testing.B) {
|
||||||
|
benchmarkResolveAddrsParallel(b, 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkResolveAddrsParallel(b *testing.B, count int) {
|
||||||
|
addrs := make([]string, count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
addrs[i] = fmt.Sprintf("127.0.%d.%d", i/256, i%256)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r := New(10 * time.Millisecond)
|
||||||
|
r.ResolveAddrsParallel(addrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkConcurrentResolveAddr(b *testing.B) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
addr := "127.0.0.1"
|
||||||
|
|
||||||
|
// pre-populate cache
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkConcurrentResolvePort(b *testing.B) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// pre-populate cache
|
||||||
|
r.ResolvePort(80, "tcp")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
r.ResolvePort(80, "tcp")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetCacheSize(b *testing.B) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// populate with some entries
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
r.ResolvePort(i+1, "tcp")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.GetCacheSize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkClearCache(b *testing.B) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// populate and clear
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
r.ResolvePort(j+1, "tcp")
|
||||||
|
}
|
||||||
|
r.ClearCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
387
internal/resolver/resolver_test.go
Normal file
387
internal/resolver/resolver_test.go
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("expected non-nil resolver")
|
||||||
|
}
|
||||||
|
if r.timeout != 100*time.Millisecond {
|
||||||
|
t.Errorf("expected timeout 100ms, got %v", r.timeout)
|
||||||
|
}
|
||||||
|
if r.cache == nil {
|
||||||
|
t.Error("expected cache to be initialized")
|
||||||
|
}
|
||||||
|
if r.noCache {
|
||||||
|
t.Error("expected noCache to be false by default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetNoCache(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
r.SetNoCache(true)
|
||||||
|
if !r.noCache {
|
||||||
|
t.Error("expected noCache to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.SetNoCache(false)
|
||||||
|
if r.noCache {
|
||||||
|
t.Error("expected noCache to be false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddr_InvalidIP(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// invalid ip should return as-is
|
||||||
|
result := r.ResolveAddr("not-an-ip")
|
||||||
|
if result != "not-an-ip" {
|
||||||
|
t.Errorf("expected 'not-an-ip', got %q", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// empty string should return as-is
|
||||||
|
result = r.ResolveAddr("")
|
||||||
|
if result != "" {
|
||||||
|
t.Errorf("expected empty string, got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddr_Caching(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// first call should cache
|
||||||
|
addr := "127.0.0.1"
|
||||||
|
result1 := r.ResolveAddr(addr)
|
||||||
|
|
||||||
|
// verify cache is populated
|
||||||
|
if r.GetCacheSize() != 1 {
|
||||||
|
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// second call should use cache
|
||||||
|
result2 := r.ResolveAddr(addr)
|
||||||
|
if result1 != result2 {
|
||||||
|
t.Errorf("expected same result from cache, got %q and %q", result1, result2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddr_NoCacheMode(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
r.SetNoCache(true)
|
||||||
|
|
||||||
|
addr := "127.0.0.1"
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
|
||||||
|
// cache should remain empty when noCache is enabled
|
||||||
|
if r.GetCacheSize() != 0 {
|
||||||
|
t.Errorf("expected cache size 0 with noCache, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePort_Zero(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
result := r.ResolvePort(0, "tcp")
|
||||||
|
if result != "0" {
|
||||||
|
t.Errorf("expected '0' for port 0, got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePort_WellKnown(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
port int
|
||||||
|
proto string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{80, "tcp", "http"},
|
||||||
|
{443, "tcp", "https"},
|
||||||
|
{22, "tcp", "ssh"},
|
||||||
|
{53, "udp", "domain"},
|
||||||
|
{5432, "tcp", "postgresql"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := r.ResolvePort(tt.port, tt.proto)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ResolvePort(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePort_Caching(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
r.ResolvePort(80, "tcp")
|
||||||
|
r.ResolvePort(443, "tcp")
|
||||||
|
|
||||||
|
if r.GetCacheSize() != 2 {
|
||||||
|
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// same port/proto should not add new entry
|
||||||
|
r.ResolvePort(80, "tcp")
|
||||||
|
if r.GetCacheSize() != 2 {
|
||||||
|
t.Errorf("expected cache size still 2, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddrPort(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
addr, port := r.ResolveAddrPort("127.0.0.1", 80, "tcp")
|
||||||
|
|
||||||
|
if addr == "" {
|
||||||
|
t.Error("expected non-empty address")
|
||||||
|
}
|
||||||
|
if port != "http" {
|
||||||
|
t.Errorf("expected port 'http', got %q", port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearCache(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
r.ResolveAddr("127.0.0.1")
|
||||||
|
r.ResolvePort(80, "tcp")
|
||||||
|
|
||||||
|
if r.GetCacheSize() == 0 {
|
||||||
|
t.Error("expected non-empty cache before clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ClearCache()
|
||||||
|
|
||||||
|
if r.GetCacheSize() != 0 {
|
||||||
|
t.Errorf("expected empty cache after clear, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCacheSize(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
if r.GetCacheSize() != 0 {
|
||||||
|
t.Errorf("expected initial cache size 0, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ResolveAddr("127.0.0.1")
|
||||||
|
if r.GetCacheSize() != 1 {
|
||||||
|
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetServiceName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
port int
|
||||||
|
proto string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{80, "tcp", "http"},
|
||||||
|
{443, "tcp", "https"},
|
||||||
|
{22, "tcp", "ssh"},
|
||||||
|
{53, "tcp", "domain"},
|
||||||
|
{53, "udp", "domain"},
|
||||||
|
{12345, "tcp", ""},
|
||||||
|
{0, "tcp", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := getServiceName(tt.port, tt.proto)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("getServiceName(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddrsParallel(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
addrs := []string{
|
||||||
|
"127.0.0.1",
|
||||||
|
"127.0.0.2",
|
||||||
|
"127.0.0.3",
|
||||||
|
"", // should be skipped
|
||||||
|
"*", // should be skipped
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ResolveAddrsParallel(addrs)
|
||||||
|
|
||||||
|
// should have cached 3 addresses (excluding empty and *)
|
||||||
|
if r.GetCacheSize() != 3 {
|
||||||
|
t.Errorf("expected cache size 3, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddrsParallel_Dedupe(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
addrs := []string{
|
||||||
|
"127.0.0.1",
|
||||||
|
"127.0.0.1",
|
||||||
|
"127.0.0.1",
|
||||||
|
"127.0.0.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ResolveAddrsParallel(addrs)
|
||||||
|
|
||||||
|
// should have cached 2 unique addresses
|
||||||
|
if r.GetCacheSize() != 2 {
|
||||||
|
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddrsParallel_SkipsCached(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// pre-cache one address
|
||||||
|
r.ResolveAddr("127.0.0.1")
|
||||||
|
|
||||||
|
addrs := []string{
|
||||||
|
"127.0.0.1", // already cached
|
||||||
|
"127.0.0.2", // not cached
|
||||||
|
}
|
||||||
|
|
||||||
|
initialSize := r.GetCacheSize()
|
||||||
|
r.ResolveAddrsParallel(addrs)
|
||||||
|
|
||||||
|
// should have added 1 more
|
||||||
|
if r.GetCacheSize() != initialSize+1 {
|
||||||
|
t.Errorf("expected cache size %d, got %d", initialSize+1, r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddrsParallel_Empty(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// should not panic with empty input
|
||||||
|
r.ResolveAddrsParallel([]string{})
|
||||||
|
r.ResolveAddrsParallel(nil)
|
||||||
|
|
||||||
|
if r.GetCacheSize() != 0 {
|
||||||
|
t.Errorf("expected cache size 0, got %d", r.GetCacheSize())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalResolver(t *testing.T) {
|
||||||
|
// reset global resolver
|
||||||
|
globalResolver = nil
|
||||||
|
|
||||||
|
r := GetGlobalResolver()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("expected non-nil global resolver")
|
||||||
|
}
|
||||||
|
|
||||||
|
// should return same instance
|
||||||
|
r2 := GetGlobalResolver()
|
||||||
|
if r != r2 {
|
||||||
|
t.Error("expected same global resolver instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetGlobalResolver(t *testing.T) {
|
||||||
|
SetGlobalResolver(ResolverOptions{
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
NoCache: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
r := GetGlobalResolver()
|
||||||
|
if r.timeout != 500*time.Millisecond {
|
||||||
|
t.Errorf("expected timeout 500ms, got %v", r.timeout)
|
||||||
|
}
|
||||||
|
if !r.noCache {
|
||||||
|
t.Error("expected noCache to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset for other tests
|
||||||
|
globalResolver = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetGlobalResolver_DefaultTimeout(t *testing.T) {
|
||||||
|
SetGlobalResolver(ResolverOptions{
|
||||||
|
Timeout: 0, // should use default
|
||||||
|
})
|
||||||
|
|
||||||
|
r := GetGlobalResolver()
|
||||||
|
if r.timeout != 200*time.Millisecond {
|
||||||
|
t.Errorf("expected default timeout 200ms, got %v", r.timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset for other tests
|
||||||
|
globalResolver = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalConvenienceFunctions(t *testing.T) {
|
||||||
|
globalResolver = nil
|
||||||
|
|
||||||
|
// test global ResolveAddr
|
||||||
|
result := ResolveAddr("127.0.0.1")
|
||||||
|
if result == "" {
|
||||||
|
t.Error("expected non-empty result from global ResolveAddr")
|
||||||
|
}
|
||||||
|
|
||||||
|
// test global ResolvePort
|
||||||
|
port := ResolvePort(80, "tcp")
|
||||||
|
if port != "http" {
|
||||||
|
t.Errorf("expected 'http', got %q", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test global ResolveAddrPort
|
||||||
|
addr, portStr := ResolveAddrPort("127.0.0.1", 443, "tcp")
|
||||||
|
if addr == "" {
|
||||||
|
t.Error("expected non-empty address")
|
||||||
|
}
|
||||||
|
if portStr != "https" {
|
||||||
|
t.Errorf("expected 'https', got %q", portStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test global SetNoCache
|
||||||
|
SetNoCache(true)
|
||||||
|
if !GetGlobalResolver().noCache {
|
||||||
|
t.Error("expected global noCache to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset
|
||||||
|
globalResolver = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(n int) {
|
||||||
|
defer wg.Done()
|
||||||
|
addr := "127.0.0.1"
|
||||||
|
r.ResolveAddr(addr)
|
||||||
|
r.ResolvePort(80+n%10, "tcp")
|
||||||
|
r.GetCacheSize()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// should not panic and cache should have entries
|
||||||
|
if r.GetCacheSize() == 0 {
|
||||||
|
t.Error("expected non-empty cache after concurrent access")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAddr_TrailingDot(t *testing.T) {
|
||||||
|
// this test verifies the trailing dot removal logic
|
||||||
|
// by checking the internal logic works correctly
|
||||||
|
r := New(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// localhost should resolve and have trailing dot removed
|
||||||
|
result := r.ResolveAddr("127.0.0.1")
|
||||||
|
if len(result) > 0 && result[len(result)-1] == '.' {
|
||||||
|
t.Error("expected trailing dot to be removed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -3,6 +3,7 @@ package tui
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/karol-broda/snitch/internal/collector"
|
"github.com/karol-broda/snitch/internal/collector"
|
||||||
|
"github.com/karol-broda/snitch/internal/resolver"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -35,11 +36,20 @@ func (m model) tick() tea.Cmd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m model) fetchData() tea.Cmd {
|
func (m model) fetchData() tea.Cmd {
|
||||||
|
resolveAddrs := m.resolveAddrs
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
conns, err := collector.GetConnections()
|
conns, err := collector.GetConnections()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errMsg{err}
|
return errMsg{err}
|
||||||
}
|
}
|
||||||
|
// pre-warm dns cache in parallel if resolution is enabled
|
||||||
|
if resolveAddrs {
|
||||||
|
addrs := make([]string, 0, len(conns)*2)
|
||||||
|
for _, c := range conns {
|
||||||
|
addrs = append(addrs, c.Laddr, c.Raddr)
|
||||||
|
}
|
||||||
|
resolver.ResolveAddrsParallel(addrs)
|
||||||
|
}
|
||||||
return dataMsg{connections: conns}
|
return dataMsg{connections: conns}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ type Options struct {
|
|||||||
FilterSet bool // true if user specified any filter flags
|
FilterSet bool // true if user specified any filter flags
|
||||||
ResolveAddrs bool // when true, resolve IP addresses to hostnames
|
ResolveAddrs bool // when true, resolve IP addresses to hostnames
|
||||||
ResolvePorts bool // when true, resolve port numbers to service names
|
ResolvePorts bool // when true, resolve port numbers to service names
|
||||||
|
NoCache bool // when true, disable DNS caching
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(opts Options) model {
|
func New(opts Options) model {
|
||||||
|
|||||||
Reference in New Issue
Block a user