fix: dns resolution taking long and add caching options (#8)
This commit is contained in:
@@ -30,8 +30,6 @@ var (
|
||||
sortBy string
|
||||
fields string
|
||||
colorMode string
|
||||
resolveAddrs bool
|
||||
resolvePorts bool
|
||||
plainOutput bool
|
||||
)
|
||||
|
||||
@@ -400,10 +398,9 @@ func init() {
|
||||
lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)")
|
||||
lsCmd.Flags().StringVarP(&fields, "fields", "f", strings.Join(cfg.Defaults.Fields, ","), "Comma-separated list of fields to show")
|
||||
lsCmd.Flags().StringVar(&colorMode, "color", cfg.Defaults.Color, "Color mode (auto, always, never)")
|
||||
lsCmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
||||
lsCmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
||||
lsCmd.Flags().BoolVarP(&plainOutput, "plain", "p", false, "Plain output (parsable, no styling)")
|
||||
|
||||
// shared filter flags
|
||||
// shared flags
|
||||
addFilterFlags(lsCmd)
|
||||
addResolutionFlags(lsCmd)
|
||||
}
|
||||
@@ -45,9 +45,8 @@ func init() {
|
||||
cfg := config.Get()
|
||||
rootCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
|
||||
rootCmd.Flags().DurationVarP(&topInterval, "interval", "i", 0, "Refresh interval (default 1s)")
|
||||
rootCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
||||
rootCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
||||
|
||||
// shared filter flags for root command
|
||||
// shared flags for root command
|
||||
addFilterFlags(rootCmd)
|
||||
addResolutionFlags(rootCmd)
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"github.com/karol-broda/snitch/internal/collector"
|
||||
"github.com/karol-broda/snitch/internal/color"
|
||||
"github.com/karol-broda/snitch/internal/config"
|
||||
"github.com/karol-broda/snitch/internal/resolver"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -11,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
// Runtime holds the shared state for all commands.
|
||||
// it handles common filter logic, fetching, and filtering connections.
|
||||
// it handles common filter logic, fetching, filtering, and resolution.
|
||||
type Runtime struct {
|
||||
// filter options built from flags and args
|
||||
Filters collector.FilterOptions
|
||||
@@ -23,6 +25,7 @@ type Runtime struct {
|
||||
ColorMode string
|
||||
ResolveAddrs bool
|
||||
ResolvePorts bool
|
||||
NoCache bool
|
||||
}
|
||||
|
||||
// shared filter flags - used by all commands
|
||||
@@ -35,6 +38,13 @@ var (
|
||||
filterIPv6 bool
|
||||
)
|
||||
|
||||
// shared resolution flags - used by all commands
|
||||
var (
|
||||
resolveAddrs bool
|
||||
resolvePorts bool
|
||||
noCache bool
|
||||
)
|
||||
|
||||
// BuildFilters constructs FilterOptions from command args and shortcut flags.
|
||||
func BuildFilters(args []string) (collector.FilterOptions, error) {
|
||||
filters, err := ParseFilterArgs(args)
|
||||
@@ -77,6 +87,12 @@ func FetchConnections(filters collector.FilterOptions) ([]collector.Connection,
|
||||
func NewRuntime(args []string, colorMode string) (*Runtime, error) {
|
||||
color.Init(colorMode)
|
||||
|
||||
cfg := config.Get()
|
||||
|
||||
// configure resolver with cache setting (flag overrides config)
|
||||
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
|
||||
resolver.SetNoCache(effectiveNoCache)
|
||||
|
||||
filters, err := BuildFilters(args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse filters: %w", err)
|
||||
@@ -87,13 +103,30 @@ func NewRuntime(args []string, colorMode string) (*Runtime, error) {
|
||||
return nil, fmt.Errorf("failed to fetch connections: %w", err)
|
||||
}
|
||||
|
||||
return &Runtime{
|
||||
rt := &Runtime{
|
||||
Filters: filters,
|
||||
Connections: connections,
|
||||
ColorMode: colorMode,
|
||||
ResolveAddrs: resolveAddrs,
|
||||
ResolvePorts: resolvePorts,
|
||||
}, nil
|
||||
NoCache: effectiveNoCache,
|
||||
}
|
||||
|
||||
// pre-warm dns cache by resolving all addresses in parallel
|
||||
if resolveAddrs {
|
||||
rt.PreWarmDNS()
|
||||
}
|
||||
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
// PreWarmDNS resolves all connection addresses in parallel to warm the cache.
|
||||
func (r *Runtime) PreWarmDNS() {
|
||||
addrs := make([]string, 0, len(r.Connections)*2)
|
||||
for _, c := range r.Connections {
|
||||
addrs = append(addrs, c.Laddr, c.Raddr)
|
||||
}
|
||||
resolver.ResolveAddrsParallel(addrs)
|
||||
}
|
||||
|
||||
// SortConnections sorts the runtime's connections in place.
|
||||
@@ -201,3 +234,11 @@ func addFilterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().BoolVarP(&filterIPv6, "ipv6", "6", false, "Only show IPv6 connections")
|
||||
}
|
||||
|
||||
// addResolutionFlags adds the common resolution flags to a command.
|
||||
func addResolutionFlags(cmd *cobra.Command) {
|
||||
cfg := config.Get()
|
||||
cmd.Flags().BoolVar(&resolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
||||
cmd.Flags().BoolVar(&resolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
||||
cmd.Flags().BoolVar(&noCache, "no-cache", !cfg.Defaults.DNSCache, "Disable DNS caching (force fresh lookups)")
|
||||
}
|
||||
|
||||
|
||||
525
cmd/runtime_test.go
Normal file
525
cmd/runtime_test.go
Normal file
@@ -0,0 +1,525 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/karol-broda/snitch/internal/collector"
|
||||
)
|
||||
|
||||
func TestParseFilterArgs_Empty(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proto != "" {
|
||||
t.Errorf("expected empty proto, got %q", filters.Proto)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Proto(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"proto=tcp"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proto != "tcp" {
|
||||
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_State(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"state=LISTEN"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.State != "LISTEN" {
|
||||
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_PID(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"pid=1234"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Pid != 1234 {
|
||||
t.Errorf("expected pid 1234, got %d", filters.Pid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_InvalidPID(t *testing.T) {
|
||||
_, err := ParseFilterArgs([]string{"pid=notanumber"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid pid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Proc(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"proc=nginx"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proc != "nginx" {
|
||||
t.Errorf("expected proc 'nginx', got %q", filters.Proc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Lport(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"lport=80"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Lport != 80 {
|
||||
t.Errorf("expected lport 80, got %d", filters.Lport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_InvalidLport(t *testing.T) {
|
||||
_, err := ParseFilterArgs([]string{"lport=notaport"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid lport")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Rport(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"rport=443"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Rport != 443 {
|
||||
t.Errorf("expected rport 443, got %d", filters.Rport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_InvalidRport(t *testing.T) {
|
||||
_, err := ParseFilterArgs([]string{"rport=invalid"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid rport")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_UserByName(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"user=root"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.User != "root" {
|
||||
t.Errorf("expected user 'root', got %q", filters.User)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_UserByUID(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"user=1000"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.UID != 1000 {
|
||||
t.Errorf("expected uid 1000, got %d", filters.UID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Laddr(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"laddr=127.0.0.1"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Laddr != "127.0.0.1" {
|
||||
t.Errorf("expected laddr '127.0.0.1', got %q", filters.Laddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Raddr(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"raddr=8.8.8.8"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Raddr != "8.8.8.8" {
|
||||
t.Errorf("expected raddr '8.8.8.8', got %q", filters.Raddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Contains(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"contains=google"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Contains != "google" {
|
||||
t.Errorf("expected contains 'google', got %q", filters.Contains)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Interface(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"if=eth0"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Interface != "eth0" {
|
||||
t.Errorf("expected interface 'eth0', got %q", filters.Interface)
|
||||
}
|
||||
|
||||
// test alternative syntax
|
||||
filters2, err := ParseFilterArgs([]string{"interface=lo"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters2.Interface != "lo" {
|
||||
t.Errorf("expected interface 'lo', got %q", filters2.Interface)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Mark(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"mark=0x1234"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Mark != "0x1234" {
|
||||
t.Errorf("expected mark '0x1234', got %q", filters.Mark)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Namespace(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"namespace=default"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Namespace != "default" {
|
||||
t.Errorf("expected namespace 'default', got %q", filters.Namespace)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Inode(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"inode=123456"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Inode != 123456 {
|
||||
t.Errorf("expected inode 123456, got %d", filters.Inode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_InvalidInode(t *testing.T) {
|
||||
_, err := ParseFilterArgs([]string{"inode=notanumber"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid inode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_Multiple(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"proto=tcp", "state=LISTEN", "lport=80"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proto != "tcp" {
|
||||
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||
}
|
||||
if filters.State != "LISTEN" {
|
||||
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||
}
|
||||
if filters.Lport != 80 {
|
||||
t.Errorf("expected lport 80, got %d", filters.Lport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_InvalidFormat(t *testing.T) {
|
||||
_, err := ParseFilterArgs([]string{"invalidformat"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid format")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_UnknownKey(t *testing.T) {
|
||||
_, err := ParseFilterArgs([]string{"unknownkey=value"})
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilterArgs_CaseInsensitiveKeys(t *testing.T) {
|
||||
filters, err := ParseFilterArgs([]string{"PROTO=tcp", "State=LISTEN"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proto != "tcp" {
|
||||
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||
}
|
||||
if filters.State != "LISTEN" {
|
||||
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFilters_TCPOnly(t *testing.T) {
|
||||
// save and restore global flags
|
||||
oldTCP, oldUDP := filterTCP, filterUDP
|
||||
defer func() {
|
||||
filterTCP, filterUDP = oldTCP, oldUDP
|
||||
}()
|
||||
|
||||
filterTCP = true
|
||||
filterUDP = false
|
||||
|
||||
filters, err := BuildFilters([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proto != "tcp" {
|
||||
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFilters_UDPOnly(t *testing.T) {
|
||||
oldTCP, oldUDP := filterTCP, filterUDP
|
||||
defer func() {
|
||||
filterTCP, filterUDP = oldTCP, oldUDP
|
||||
}()
|
||||
|
||||
filterTCP = false
|
||||
filterUDP = true
|
||||
|
||||
filters, err := BuildFilters([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proto != "udp" {
|
||||
t.Errorf("expected proto 'udp', got %q", filters.Proto)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFilters_ListenOnly(t *testing.T) {
|
||||
oldListen, oldEstab := filterListen, filterEstab
|
||||
defer func() {
|
||||
filterListen, filterEstab = oldListen, oldEstab
|
||||
}()
|
||||
|
||||
filterListen = true
|
||||
filterEstab = false
|
||||
|
||||
filters, err := BuildFilters([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.State != "LISTEN" {
|
||||
t.Errorf("expected state 'LISTEN', got %q", filters.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFilters_EstablishedOnly(t *testing.T) {
|
||||
oldListen, oldEstab := filterListen, filterEstab
|
||||
defer func() {
|
||||
filterListen, filterEstab = oldListen, oldEstab
|
||||
}()
|
||||
|
||||
filterListen = false
|
||||
filterEstab = true
|
||||
|
||||
filters, err := BuildFilters([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.State != "ESTABLISHED" {
|
||||
t.Errorf("expected state 'ESTABLISHED', got %q", filters.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFilters_IPv4Flag(t *testing.T) {
|
||||
oldIPv4 := filterIPv4
|
||||
defer func() {
|
||||
filterIPv4 = oldIPv4
|
||||
}()
|
||||
|
||||
filterIPv4 = true
|
||||
|
||||
filters, err := BuildFilters([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !filters.IPv4 {
|
||||
t.Error("expected IPv4 to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFilters_IPv6Flag(t *testing.T) {
|
||||
oldIPv6 := filterIPv6
|
||||
defer func() {
|
||||
filterIPv6 = oldIPv6
|
||||
}()
|
||||
|
||||
filterIPv6 = true
|
||||
|
||||
filters, err := BuildFilters([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !filters.IPv6 {
|
||||
t.Error("expected IPv6 to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFilters_CombinedArgsAndFlags(t *testing.T) {
|
||||
oldTCP := filterTCP
|
||||
defer func() {
|
||||
filterTCP = oldTCP
|
||||
}()
|
||||
|
||||
filterTCP = true
|
||||
|
||||
filters, err := BuildFilters([]string{"lport=80"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if filters.Proto != "tcp" {
|
||||
t.Errorf("expected proto 'tcp', got %q", filters.Proto)
|
||||
}
|
||||
if filters.Lport != 80 {
|
||||
t.Errorf("expected lport 80, got %d", filters.Lport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntime_PreWarmDNS(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
Connections: []collector.Connection{
|
||||
{Laddr: "127.0.0.1", Raddr: "192.168.1.1"},
|
||||
{Laddr: "127.0.0.1", Raddr: "10.0.0.1"},
|
||||
},
|
||||
}
|
||||
|
||||
// should not panic
|
||||
rt.PreWarmDNS()
|
||||
}
|
||||
|
||||
func TestRuntime_PreWarmDNS_Empty(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
Connections: []collector.Connection{},
|
||||
}
|
||||
|
||||
// should not panic with empty connections
|
||||
rt.PreWarmDNS()
|
||||
}
|
||||
|
||||
func TestRuntime_SortConnections(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
Connections: []collector.Connection{
|
||||
{Lport: 443},
|
||||
{Lport: 80},
|
||||
{Lport: 8080},
|
||||
},
|
||||
}
|
||||
|
||||
rt.SortConnections(collector.SortOptions{
|
||||
Field: collector.SortByLport,
|
||||
Direction: collector.SortAsc,
|
||||
})
|
||||
|
||||
if rt.Connections[0].Lport != 80 {
|
||||
t.Errorf("expected first connection to have lport 80, got %d", rt.Connections[0].Lport)
|
||||
}
|
||||
if rt.Connections[1].Lport != 443 {
|
||||
t.Errorf("expected second connection to have lport 443, got %d", rt.Connections[1].Lport)
|
||||
}
|
||||
if rt.Connections[2].Lport != 8080 {
|
||||
t.Errorf("expected third connection to have lport 8080, got %d", rt.Connections[2].Lport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntime_SortConnections_Desc(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
Connections: []collector.Connection{
|
||||
{Lport: 80},
|
||||
{Lport: 443},
|
||||
{Lport: 8080},
|
||||
},
|
||||
}
|
||||
|
||||
rt.SortConnections(collector.SortOptions{
|
||||
Field: collector.SortByLport,
|
||||
Direction: collector.SortDesc,
|
||||
})
|
||||
|
||||
if rt.Connections[0].Lport != 8080 {
|
||||
t.Errorf("expected first connection to have lport 8080, got %d", rt.Connections[0].Lport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyFilter_AllKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
value string
|
||||
validate func(t *testing.T, f *collector.FilterOptions)
|
||||
}{
|
||||
{"proto", "tcp", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Proto != "tcp" {
|
||||
t.Errorf("proto: expected 'tcp', got %q", f.Proto)
|
||||
}
|
||||
}},
|
||||
{"state", "LISTEN", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.State != "LISTEN" {
|
||||
t.Errorf("state: expected 'LISTEN', got %q", f.State)
|
||||
}
|
||||
}},
|
||||
{"pid", "100", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Pid != 100 {
|
||||
t.Errorf("pid: expected 100, got %d", f.Pid)
|
||||
}
|
||||
}},
|
||||
{"proc", "nginx", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Proc != "nginx" {
|
||||
t.Errorf("proc: expected 'nginx', got %q", f.Proc)
|
||||
}
|
||||
}},
|
||||
{"lport", "80", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Lport != 80 {
|
||||
t.Errorf("lport: expected 80, got %d", f.Lport)
|
||||
}
|
||||
}},
|
||||
{"rport", "443", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Rport != 443 {
|
||||
t.Errorf("rport: expected 443, got %d", f.Rport)
|
||||
}
|
||||
}},
|
||||
{"laddr", "127.0.0.1", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Laddr != "127.0.0.1" {
|
||||
t.Errorf("laddr: expected '127.0.0.1', got %q", f.Laddr)
|
||||
}
|
||||
}},
|
||||
{"raddr", "8.8.8.8", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Raddr != "8.8.8.8" {
|
||||
t.Errorf("raddr: expected '8.8.8.8', got %q", f.Raddr)
|
||||
}
|
||||
}},
|
||||
{"contains", "test", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Contains != "test" {
|
||||
t.Errorf("contains: expected 'test', got %q", f.Contains)
|
||||
}
|
||||
}},
|
||||
{"if", "eth0", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Interface != "eth0" {
|
||||
t.Errorf("interface: expected 'eth0', got %q", f.Interface)
|
||||
}
|
||||
}},
|
||||
{"mark", "0xff", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Mark != "0xff" {
|
||||
t.Errorf("mark: expected '0xff', got %q", f.Mark)
|
||||
}
|
||||
}},
|
||||
{"namespace", "ns1", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Namespace != "ns1" {
|
||||
t.Errorf("namespace: expected 'ns1', got %q", f.Namespace)
|
||||
}
|
||||
}},
|
||||
{"inode", "12345", func(t *testing.T, f *collector.FilterOptions) {
|
||||
if f.Inode != 12345 {
|
||||
t.Errorf("inode: expected 12345, got %d", f.Inode)
|
||||
}
|
||||
}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
filters := &collector.FilterOptions{}
|
||||
err := applyFilter(filters, tt.key, tt.value)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
tt.validate(t, filters)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
21
cmd/top.go
21
cmd/top.go
@@ -6,16 +6,15 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/karol-broda/snitch/internal/config"
|
||||
"github.com/karol-broda/snitch/internal/resolver"
|
||||
"github.com/karol-broda/snitch/internal/tui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// top-specific flags
|
||||
var (
|
||||
topTheme string
|
||||
topInterval time.Duration
|
||||
topResolveAddrs bool
|
||||
topResolvePorts bool
|
||||
topTheme string
|
||||
topInterval time.Duration
|
||||
)
|
||||
|
||||
var topCmd = &cobra.Command{
|
||||
@@ -29,11 +28,16 @@ var topCmd = &cobra.Command{
|
||||
theme = cfg.Defaults.Theme
|
||||
}
|
||||
|
||||
// configure resolver with cache setting
|
||||
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
|
||||
resolver.SetNoCache(effectiveNoCache)
|
||||
|
||||
opts := tui.Options{
|
||||
Theme: theme,
|
||||
Interval: topInterval,
|
||||
ResolveAddrs: topResolveAddrs,
|
||||
ResolvePorts: topResolvePorts,
|
||||
ResolveAddrs: resolveAddrs,
|
||||
ResolvePorts: resolvePorts,
|
||||
NoCache: effectiveNoCache,
|
||||
}
|
||||
|
||||
// if any filter flag is set, use exclusive mode
|
||||
@@ -62,9 +66,8 @@ func init() {
|
||||
// top-specific flags
|
||||
topCmd.Flags().StringVar(&topTheme, "theme", cfg.Defaults.Theme, "Theme for TUI (see 'snitch themes')")
|
||||
topCmd.Flags().DurationVarP(&topInterval, "interval", "i", time.Second, "Refresh interval")
|
||||
topCmd.Flags().BoolVar(&topResolveAddrs, "resolve-addrs", !cfg.Defaults.Numeric, "Resolve IP addresses to hostnames")
|
||||
topCmd.Flags().BoolVar(&topResolvePorts, "resolve-ports", false, "Resolve port numbers to service names")
|
||||
|
||||
// shared filter flags
|
||||
// shared flags
|
||||
addFilterFlags(topCmd)
|
||||
addResolutionFlags(topCmd)
|
||||
}
|
||||
29
cmd/trace.go
29
cmd/trace.go
@@ -7,12 +7,14 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"github.com/karol-broda/snitch/internal/collector"
|
||||
"github.com/karol-broda/snitch/internal/resolver"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/karol-broda/snitch/internal/collector"
|
||||
"github.com/karol-broda/snitch/internal/config"
|
||||
"github.com/karol-broda/snitch/internal/resolver"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -23,11 +25,10 @@ type TraceEvent struct {
|
||||
}
|
||||
|
||||
var (
|
||||
traceInterval time.Duration
|
||||
traceCount int
|
||||
traceInterval time.Duration
|
||||
traceCount int
|
||||
traceOutputFormat string
|
||||
traceNumeric bool
|
||||
traceTimestamp bool
|
||||
traceTimestamp bool
|
||||
)
|
||||
|
||||
var traceCmd = &cobra.Command{
|
||||
@@ -47,6 +48,12 @@ Available filters:
|
||||
}
|
||||
|
||||
func runTraceCommand(args []string) {
|
||||
cfg := config.Get()
|
||||
|
||||
// configure resolver with cache setting
|
||||
effectiveNoCache := noCache || !cfg.Defaults.DNSCache
|
||||
resolver.SetNoCache(effectiveNoCache)
|
||||
|
||||
filters, err := BuildFilters(args)
|
||||
if err != nil {
|
||||
log.Fatalf("Error parsing filters: %v", err)
|
||||
@@ -180,14 +187,16 @@ func printTraceEventHuman(event TraceEvent) {
|
||||
lportStr := fmt.Sprintf("%d", conn.Lport)
|
||||
rportStr := fmt.Sprintf("%d", conn.Rport)
|
||||
|
||||
// Handle name resolution based on numeric flag
|
||||
if !traceNumeric {
|
||||
// apply name resolution
|
||||
if resolveAddrs {
|
||||
if resolvedLaddr := resolver.ResolveAddr(conn.Laddr); resolvedLaddr != conn.Laddr {
|
||||
laddr = resolvedLaddr
|
||||
}
|
||||
if resolvedRaddr := resolver.ResolveAddr(conn.Raddr); resolvedRaddr != conn.Raddr && conn.Raddr != "*" && conn.Raddr != "" {
|
||||
raddr = resolvedRaddr
|
||||
}
|
||||
}
|
||||
if resolvePorts {
|
||||
if resolvedLport := resolver.ResolvePort(conn.Lport, conn.Proto); resolvedLport != fmt.Sprintf("%d", conn.Lport) {
|
||||
lportStr = resolvedLport
|
||||
}
|
||||
@@ -225,9 +234,9 @@ func init() {
|
||||
traceCmd.Flags().DurationVarP(&traceInterval, "interval", "i", time.Second, "Polling interval (e.g., 500ms, 2s)")
|
||||
traceCmd.Flags().IntVarP(&traceCount, "count", "c", 0, "Number of events to capture (0 = unlimited)")
|
||||
traceCmd.Flags().StringVarP(&traceOutputFormat, "output", "o", "human", "Output format (human, json)")
|
||||
traceCmd.Flags().BoolVarP(&traceNumeric, "numeric", "n", false, "Don't resolve hostnames")
|
||||
traceCmd.Flags().BoolVar(&traceTimestamp, "ts", false, "Include timestamp in output")
|
||||
|
||||
// shared filter flags
|
||||
// shared flags
|
||||
addFilterFlags(traceCmd)
|
||||
addResolutionFlags(traceCmd)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user