fix: dns resolution taking long and add caching options (#8)
This commit is contained in:
@@ -25,6 +25,7 @@ type DefaultConfig struct {
|
||||
Units string `mapstructure:"units"`
|
||||
Color string `mapstructure:"color"`
|
||||
Resolve bool `mapstructure:"resolve"`
|
||||
DNSCache bool `mapstructure:"dns_cache"`
|
||||
IPv4 bool `mapstructure:"ipv4"`
|
||||
IPv6 bool `mapstructure:"ipv6"`
|
||||
NoHeaders bool `mapstructure:"no_headers"`
|
||||
@@ -57,6 +58,7 @@ func Load() (*Config, error) {
|
||||
// environment variable bindings for readme-documented variables
|
||||
_ = v.BindEnv("config", "SNITCH_CONFIG")
|
||||
_ = v.BindEnv("defaults.resolve", "SNITCH_RESOLVE")
|
||||
_ = v.BindEnv("defaults.dns_cache", "SNITCH_DNS_CACHE")
|
||||
_ = v.BindEnv("defaults.theme", "SNITCH_THEME")
|
||||
_ = v.BindEnv("defaults.color", "SNITCH_NO_COLOR")
|
||||
|
||||
@@ -90,7 +92,6 @@ func Load() (*Config, error) {
|
||||
}
|
||||
|
||||
func setDefaults(v *viper.Viper) {
|
||||
// Set default values matching the README specification
|
||||
v.SetDefault("defaults.interval", "1s")
|
||||
v.SetDefault("defaults.numeric", false)
|
||||
v.SetDefault("defaults.fields", []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"})
|
||||
@@ -98,6 +99,7 @@ func setDefaults(v *viper.Viper) {
|
||||
v.SetDefault("defaults.units", "auto")
|
||||
v.SetDefault("defaults.color", "auto")
|
||||
v.SetDefault("defaults.resolve", true)
|
||||
v.SetDefault("defaults.dns_cache", true)
|
||||
v.SetDefault("defaults.ipv4", false)
|
||||
v.SetDefault("defaults.ipv6", false)
|
||||
v.SetDefault("defaults.no_headers", false)
|
||||
@@ -116,6 +118,11 @@ func handleSpecialEnvVars(v *viper.Viper) {
|
||||
v.Set("defaults.resolve", false)
|
||||
v.Set("defaults.numeric", true)
|
||||
}
|
||||
|
||||
// Handle SNITCH_DNS_CACHE - if set to "0", disable dns caching
|
||||
if os.Getenv("SNITCH_DNS_CACHE") == "0" {
|
||||
v.Set("defaults.dns_cache", false)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the global configuration, loading it if necessary
|
||||
@@ -123,7 +130,6 @@ func Get() *Config {
|
||||
if globalConfig == nil {
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
// Return default config on error
|
||||
return &Config{
|
||||
Defaults: DefaultConfig{
|
||||
Interval: "1s",
|
||||
@@ -133,6 +139,7 @@ func Get() *Config {
|
||||
Units: "auto",
|
||||
Color: "auto",
|
||||
Resolve: true,
|
||||
DNSCache: true,
|
||||
IPv4: false,
|
||||
IPv6: false,
|
||||
NoHeaders: false,
|
||||
|
||||
@@ -14,9 +14,10 @@ var debugTiming = os.Getenv("SNITCH_DEBUG_TIMING") != ""
|
||||
|
||||
// Resolver handles DNS and service name resolution with caching and timeouts
|
||||
type Resolver struct {
|
||||
timeout time.Duration
|
||||
cache map[string]string
|
||||
mutex sync.RWMutex
|
||||
timeout time.Duration
|
||||
cache map[string]string
|
||||
mutex sync.RWMutex
|
||||
noCache bool
|
||||
}
|
||||
|
||||
// New creates a new resolver with the specified timeout
|
||||
@@ -24,37 +25,44 @@ func New(timeout time.Duration) *Resolver {
|
||||
return &Resolver{
|
||||
timeout: timeout,
|
||||
cache: make(map[string]string),
|
||||
noCache: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetNoCache disables caching - each lookup will hit DNS directly
|
||||
func (r *Resolver) SetNoCache(noCache bool) {
|
||||
r.noCache = noCache
|
||||
}
|
||||
|
||||
// ResolveAddr resolves an IP address to a hostname, with caching
|
||||
func (r *Resolver) ResolveAddr(addr string) string {
|
||||
// Check cache first
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[addr]; exists {
|
||||
// check cache first (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[addr]; exists {
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
// Parse IP to validate it
|
||||
// parse ip to validate it
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
// Not a valid IP, return as-is
|
||||
return addr
|
||||
}
|
||||
|
||||
// Perform resolution with timeout
|
||||
// perform resolution with timeout
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
names, err := net.DefaultResolver.LookupAddr(ctx, addr)
|
||||
|
||||
resolved := addr // fallback to original address
|
||||
resolved := addr
|
||||
if err == nil && len(names) > 0 {
|
||||
resolved = names[0]
|
||||
// Remove trailing dot if present
|
||||
// remove trailing dot if present
|
||||
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
|
||||
resolved = resolved[:len(resolved)-1]
|
||||
}
|
||||
@@ -65,10 +73,12 @@ func (r *Resolver) ResolveAddr(addr string) string {
|
||||
fmt.Fprintf(os.Stderr, "[timing] slow DNS lookup: %s -> %s (%v)\n", addr, resolved, elapsed)
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
r.mutex.Lock()
|
||||
r.cache[addr] = resolved
|
||||
r.mutex.Unlock()
|
||||
// cache the result (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.Lock()
|
||||
r.cache[addr] = resolved
|
||||
r.mutex.Unlock()
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
@@ -81,15 +91,17 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
|
||||
|
||||
cacheKey := strconv.Itoa(port) + "/" + proto
|
||||
|
||||
// Check cache first
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[cacheKey]; exists {
|
||||
// check cache first (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[cacheKey]; exists {
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
// Perform resolution with timeout
|
||||
// perform resolution with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -97,16 +109,18 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
|
||||
|
||||
resolved := strconv.Itoa(port) // fallback to port number
|
||||
if err == nil && service != 0 {
|
||||
// Try to get service name
|
||||
// try to get service name
|
||||
if serviceName := getServiceName(port, proto); serviceName != "" {
|
||||
resolved = serviceName
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
r.mutex.Lock()
|
||||
r.cache[cacheKey] = resolved
|
||||
r.mutex.Unlock()
|
||||
// cache the result (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.Lock()
|
||||
r.cache[cacheKey] = resolved
|
||||
r.mutex.Unlock()
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
@@ -169,22 +183,38 @@ func getServiceName(port int, proto string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Global resolver instance
|
||||
// global resolver instance
|
||||
var globalResolver *Resolver
|
||||
|
||||
// SetGlobalResolver sets the global resolver instance
|
||||
func SetGlobalResolver(timeout time.Duration) {
|
||||
// ResolverOptions configures the global resolver
|
||||
type ResolverOptions struct {
|
||||
Timeout time.Duration
|
||||
NoCache bool
|
||||
}
|
||||
|
||||
// SetGlobalResolver sets the global resolver instance with options
|
||||
func SetGlobalResolver(opts ResolverOptions) {
|
||||
timeout := opts.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 200 * time.Millisecond
|
||||
}
|
||||
globalResolver = New(timeout)
|
||||
globalResolver.SetNoCache(opts.NoCache)
|
||||
}
|
||||
|
||||
// GetGlobalResolver returns the global resolver instance
|
||||
func GetGlobalResolver() *Resolver {
|
||||
if globalResolver == nil {
|
||||
globalResolver = New(200 * time.Millisecond) // Default timeout
|
||||
globalResolver = New(200 * time.Millisecond)
|
||||
}
|
||||
return globalResolver
|
||||
}
|
||||
|
||||
// SetNoCache configures whether the global resolver bypasses cache
|
||||
func SetNoCache(noCache bool) {
|
||||
GetGlobalResolver().SetNoCache(noCache)
|
||||
}
|
||||
|
||||
// ResolveAddr is a convenience function using the global resolver
|
||||
func ResolveAddr(addr string) string {
|
||||
return GetGlobalResolver().ResolveAddr(addr)
|
||||
@@ -199,3 +229,48 @@ func ResolvePort(port int, proto string) string {
|
||||
func ResolveAddrPort(addr string, port int, proto string) (string, string) {
|
||||
return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
|
||||
}
|
||||
|
||||
// ResolveAddrsParallel resolves multiple addresses concurrently and caches results.
|
||||
// This should be called before rendering to pre-warm the cache.
|
||||
func (r *Resolver) ResolveAddrsParallel(addrs []string) {
|
||||
// dedupe and filter addresses that need resolution
|
||||
unique := make(map[string]struct{})
|
||||
for _, addr := range addrs {
|
||||
if addr == "" || addr == "*" {
|
||||
continue
|
||||
}
|
||||
// skip if already cached
|
||||
r.mutex.RLock()
|
||||
_, exists := r.cache[addr]
|
||||
r.mutex.RUnlock()
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
unique[addr] = struct{}{}
|
||||
}
|
||||
|
||||
if len(unique) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
// limit concurrency to avoid overwhelming dns
|
||||
sem := make(chan struct{}, 32)
|
||||
|
||||
for addr := range unique {
|
||||
wg.Add(1)
|
||||
go func(a string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
r.ResolveAddr(a)
|
||||
}(addr)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// ResolveAddrsParallel is a convenience function using the global resolver
|
||||
func ResolveAddrsParallel(addrs []string) {
|
||||
GetGlobalResolver().ResolveAddrsParallel(addrs)
|
||||
}
|
||||
|
||||
159
internal/resolver/resolver_bench_test.go
Normal file
159
internal/resolver/resolver_bench_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkResolveAddr_CacheHit(b *testing.B) {
|
||||
r := New(100 * time.Millisecond)
|
||||
addr := "127.0.0.1"
|
||||
|
||||
// pre-populate cache
|
||||
r.ResolveAddr(addr)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.ResolveAddr(addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkResolveAddr_CacheMiss(b *testing.B) {
|
||||
r := New(10 * time.Millisecond) // short timeout for faster benchmarks
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// use different addresses to avoid cache hits
|
||||
addr := fmt.Sprintf("127.0.0.%d", i%256)
|
||||
r.ClearCache() // clear cache to force miss
|
||||
r.ResolveAddr(addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkResolveAddr_NoCache(b *testing.B) {
|
||||
r := New(10 * time.Millisecond)
|
||||
r.SetNoCache(true)
|
||||
addr := "127.0.0.1"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.ResolveAddr(addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkResolvePort_CacheHit(b *testing.B) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// pre-populate cache
|
||||
r.ResolvePort(80, "tcp")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.ResolvePort(80, "tcp")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkResolvePort_WellKnown(b *testing.B) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.ClearCache()
|
||||
r.ResolvePort(443, "tcp")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetServiceName(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
getServiceName(80, "tcp")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetServiceName_NotFound(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
getServiceName(12345, "tcp")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkResolveAddrsParallel_10(b *testing.B) {
|
||||
benchmarkResolveAddrsParallel(b, 10)
|
||||
}
|
||||
|
||||
func BenchmarkResolveAddrsParallel_100(b *testing.B) {
|
||||
benchmarkResolveAddrsParallel(b, 100)
|
||||
}
|
||||
|
||||
func BenchmarkResolveAddrsParallel_1000(b *testing.B) {
|
||||
benchmarkResolveAddrsParallel(b, 1000)
|
||||
}
|
||||
|
||||
func benchmarkResolveAddrsParallel(b *testing.B, count int) {
|
||||
addrs := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
addrs[i] = fmt.Sprintf("127.0.%d.%d", i/256, i%256)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r := New(10 * time.Millisecond)
|
||||
r.ResolveAddrsParallel(addrs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConcurrentResolveAddr(b *testing.B) {
|
||||
r := New(100 * time.Millisecond)
|
||||
addr := "127.0.0.1"
|
||||
|
||||
// pre-populate cache
|
||||
r.ResolveAddr(addr)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
r.ResolveAddr(addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkConcurrentResolvePort(b *testing.B) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// pre-populate cache
|
||||
r.ResolvePort(80, "tcp")
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
r.ResolvePort(80, "tcp")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkGetCacheSize(b *testing.B) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// populate with some entries
|
||||
for i := 0; i < 100; i++ {
|
||||
r.ResolvePort(i+1, "tcp")
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.GetCacheSize()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkClearCache(b *testing.B) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// populate and clear
|
||||
for j := 0; j < 10; j++ {
|
||||
r.ResolvePort(j+1, "tcp")
|
||||
}
|
||||
r.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
387
internal/resolver/resolver_test.go
Normal file
387
internal/resolver/resolver_test.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
if r == nil {
|
||||
t.Fatal("expected non-nil resolver")
|
||||
}
|
||||
if r.timeout != 100*time.Millisecond {
|
||||
t.Errorf("expected timeout 100ms, got %v", r.timeout)
|
||||
}
|
||||
if r.cache == nil {
|
||||
t.Error("expected cache to be initialized")
|
||||
}
|
||||
if r.noCache {
|
||||
t.Error("expected noCache to be false by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetNoCache(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
r.SetNoCache(true)
|
||||
if !r.noCache {
|
||||
t.Error("expected noCache to be true")
|
||||
}
|
||||
|
||||
r.SetNoCache(false)
|
||||
if r.noCache {
|
||||
t.Error("expected noCache to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddr_InvalidIP(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// invalid ip should return as-is
|
||||
result := r.ResolveAddr("not-an-ip")
|
||||
if result != "not-an-ip" {
|
||||
t.Errorf("expected 'not-an-ip', got %q", result)
|
||||
}
|
||||
|
||||
// empty string should return as-is
|
||||
result = r.ResolveAddr("")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddr_Caching(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// first call should cache
|
||||
addr := "127.0.0.1"
|
||||
result1 := r.ResolveAddr(addr)
|
||||
|
||||
// verify cache is populated
|
||||
if r.GetCacheSize() != 1 {
|
||||
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
|
||||
}
|
||||
|
||||
// second call should use cache
|
||||
result2 := r.ResolveAddr(addr)
|
||||
if result1 != result2 {
|
||||
t.Errorf("expected same result from cache, got %q and %q", result1, result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddr_NoCacheMode(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
r.SetNoCache(true)
|
||||
|
||||
addr := "127.0.0.1"
|
||||
r.ResolveAddr(addr)
|
||||
|
||||
// cache should remain empty when noCache is enabled
|
||||
if r.GetCacheSize() != 0 {
|
||||
t.Errorf("expected cache size 0 with noCache, got %d", r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePort_Zero(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
result := r.ResolvePort(0, "tcp")
|
||||
if result != "0" {
|
||||
t.Errorf("expected '0' for port 0, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePort_WellKnown(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
port int
|
||||
proto string
|
||||
expected string
|
||||
}{
|
||||
{80, "tcp", "http"},
|
||||
{443, "tcp", "https"},
|
||||
{22, "tcp", "ssh"},
|
||||
{53, "udp", "domain"},
|
||||
{5432, "tcp", "postgresql"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := r.ResolvePort(tt.port, tt.proto)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ResolvePort(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePort_Caching(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
r.ResolvePort(80, "tcp")
|
||||
r.ResolvePort(443, "tcp")
|
||||
|
||||
if r.GetCacheSize() != 2 {
|
||||
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
|
||||
}
|
||||
|
||||
// same port/proto should not add new entry
|
||||
r.ResolvePort(80, "tcp")
|
||||
if r.GetCacheSize() != 2 {
|
||||
t.Errorf("expected cache size still 2, got %d", r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddrPort(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
addr, port := r.ResolveAddrPort("127.0.0.1", 80, "tcp")
|
||||
|
||||
if addr == "" {
|
||||
t.Error("expected non-empty address")
|
||||
}
|
||||
if port != "http" {
|
||||
t.Errorf("expected port 'http', got %q", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCache(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
r.ResolveAddr("127.0.0.1")
|
||||
r.ResolvePort(80, "tcp")
|
||||
|
||||
if r.GetCacheSize() == 0 {
|
||||
t.Error("expected non-empty cache before clear")
|
||||
}
|
||||
|
||||
r.ClearCache()
|
||||
|
||||
if r.GetCacheSize() != 0 {
|
||||
t.Errorf("expected empty cache after clear, got %d", r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheSize(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
if r.GetCacheSize() != 0 {
|
||||
t.Errorf("expected initial cache size 0, got %d", r.GetCacheSize())
|
||||
}
|
||||
|
||||
r.ResolveAddr("127.0.0.1")
|
||||
if r.GetCacheSize() != 1 {
|
||||
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServiceName(t *testing.T) {
|
||||
tests := []struct {
|
||||
port int
|
||||
proto string
|
||||
expected string
|
||||
}{
|
||||
{80, "tcp", "http"},
|
||||
{443, "tcp", "https"},
|
||||
{22, "tcp", "ssh"},
|
||||
{53, "tcp", "domain"},
|
||||
{53, "udp", "domain"},
|
||||
{12345, "tcp", ""},
|
||||
{0, "tcp", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := getServiceName(tt.port, tt.proto)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getServiceName(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddrsParallel(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
addrs := []string{
|
||||
"127.0.0.1",
|
||||
"127.0.0.2",
|
||||
"127.0.0.3",
|
||||
"", // should be skipped
|
||||
"*", // should be skipped
|
||||
}
|
||||
|
||||
r.ResolveAddrsParallel(addrs)
|
||||
|
||||
// should have cached 3 addresses (excluding empty and *)
|
||||
if r.GetCacheSize() != 3 {
|
||||
t.Errorf("expected cache size 3, got %d", r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddrsParallel_Dedupe(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
addrs := []string{
|
||||
"127.0.0.1",
|
||||
"127.0.0.1",
|
||||
"127.0.0.1",
|
||||
"127.0.0.2",
|
||||
}
|
||||
|
||||
r.ResolveAddrsParallel(addrs)
|
||||
|
||||
// should have cached 2 unique addresses
|
||||
if r.GetCacheSize() != 2 {
|
||||
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddrsParallel_SkipsCached(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// pre-cache one address
|
||||
r.ResolveAddr("127.0.0.1")
|
||||
|
||||
addrs := []string{
|
||||
"127.0.0.1", // already cached
|
||||
"127.0.0.2", // not cached
|
||||
}
|
||||
|
||||
initialSize := r.GetCacheSize()
|
||||
r.ResolveAddrsParallel(addrs)
|
||||
|
||||
// should have added 1 more
|
||||
if r.GetCacheSize() != initialSize+1 {
|
||||
t.Errorf("expected cache size %d, got %d", initialSize+1, r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddrsParallel_Empty(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// should not panic with empty input
|
||||
r.ResolveAddrsParallel([]string{})
|
||||
r.ResolveAddrsParallel(nil)
|
||||
|
||||
if r.GetCacheSize() != 0 {
|
||||
t.Errorf("expected cache size 0, got %d", r.GetCacheSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalResolver(t *testing.T) {
|
||||
// reset global resolver
|
||||
globalResolver = nil
|
||||
|
||||
r := GetGlobalResolver()
|
||||
if r == nil {
|
||||
t.Fatal("expected non-nil global resolver")
|
||||
}
|
||||
|
||||
// should return same instance
|
||||
r2 := GetGlobalResolver()
|
||||
if r != r2 {
|
||||
t.Error("expected same global resolver instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetGlobalResolver(t *testing.T) {
|
||||
SetGlobalResolver(ResolverOptions{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
NoCache: true,
|
||||
})
|
||||
|
||||
r := GetGlobalResolver()
|
||||
if r.timeout != 500*time.Millisecond {
|
||||
t.Errorf("expected timeout 500ms, got %v", r.timeout)
|
||||
}
|
||||
if !r.noCache {
|
||||
t.Error("expected noCache to be true")
|
||||
}
|
||||
|
||||
// reset for other tests
|
||||
globalResolver = nil
|
||||
}
|
||||
|
||||
func TestSetGlobalResolver_DefaultTimeout(t *testing.T) {
|
||||
SetGlobalResolver(ResolverOptions{
|
||||
Timeout: 0, // should use default
|
||||
})
|
||||
|
||||
r := GetGlobalResolver()
|
||||
if r.timeout != 200*time.Millisecond {
|
||||
t.Errorf("expected default timeout 200ms, got %v", r.timeout)
|
||||
}
|
||||
|
||||
// reset for other tests
|
||||
globalResolver = nil
|
||||
}
|
||||
|
||||
func TestGlobalConvenienceFunctions(t *testing.T) {
|
||||
globalResolver = nil
|
||||
|
||||
// test global ResolveAddr
|
||||
result := ResolveAddr("127.0.0.1")
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result from global ResolveAddr")
|
||||
}
|
||||
|
||||
// test global ResolvePort
|
||||
port := ResolvePort(80, "tcp")
|
||||
if port != "http" {
|
||||
t.Errorf("expected 'http', got %q", port)
|
||||
}
|
||||
|
||||
// test global ResolveAddrPort
|
||||
addr, portStr := ResolveAddrPort("127.0.0.1", 443, "tcp")
|
||||
if addr == "" {
|
||||
t.Error("expected non-empty address")
|
||||
}
|
||||
if portStr != "https" {
|
||||
t.Errorf("expected 'https', got %q", portStr)
|
||||
}
|
||||
|
||||
// test global SetNoCache
|
||||
SetNoCache(true)
|
||||
if !GetGlobalResolver().noCache {
|
||||
t.Error("expected global noCache to be true")
|
||||
}
|
||||
|
||||
// reset
|
||||
globalResolver = nil
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
addr := "127.0.0.1"
|
||||
r.ResolveAddr(addr)
|
||||
r.ResolvePort(80+n%10, "tcp")
|
||||
r.GetCacheSize()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// should not panic and cache should have entries
|
||||
if r.GetCacheSize() == 0 {
|
||||
t.Error("expected non-empty cache after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAddr_TrailingDot(t *testing.T) {
|
||||
// this test verifies the trailing dot removal logic
|
||||
// by checking the internal logic works correctly
|
||||
r := New(100 * time.Millisecond)
|
||||
|
||||
// localhost should resolve and have trailing dot removed
|
||||
result := r.ResolveAddr("127.0.0.1")
|
||||
if len(result) > 0 && result[len(result)-1] == '.' {
|
||||
t.Error("expected trailing dot to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package tui
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/karol-broda/snitch/internal/collector"
|
||||
"github.com/karol-broda/snitch/internal/resolver"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -35,11 +36,20 @@ func (m model) tick() tea.Cmd {
|
||||
}
|
||||
|
||||
func (m model) fetchData() tea.Cmd {
|
||||
resolveAddrs := m.resolveAddrs
|
||||
return func() tea.Msg {
|
||||
conns, err := collector.GetConnections()
|
||||
if err != nil {
|
||||
return errMsg{err}
|
||||
}
|
||||
// pre-warm dns cache in parallel if resolution is enabled
|
||||
if resolveAddrs {
|
||||
addrs := make([]string, 0, len(conns)*2)
|
||||
for _, c := range conns {
|
||||
addrs = append(addrs, c.Laddr, c.Raddr)
|
||||
}
|
||||
resolver.ResolveAddrsParallel(addrs)
|
||||
}
|
||||
return dataMsg{connections: conns}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ type Options struct {
|
||||
FilterSet bool // true if user specified any filter flags
|
||||
ResolveAddrs bool // when true, resolve IP addresses to hostnames
|
||||
ResolvePorts bool // when true, resolve port numbers to service names
|
||||
NoCache bool // when true, disable DNS caching
|
||||
}
|
||||
|
||||
func New(opts Options) model {
|
||||
|
||||
Reference in New Issue
Block a user