fix: dns resolution taking long and add caching options (#8)

This commit is contained in:
Karol Broda
2025-12-24 11:12:39 +01:00
committed by Karol Broda
parent 933da8bad2
commit 49948de0ed
13 changed files with 1302 additions and 67 deletions

View File

@@ -14,9 +14,10 @@ var debugTiming = os.Getenv("SNITCH_DEBUG_TIMING") != ""
// Resolver handles DNS and service name resolution with caching and timeouts
type Resolver struct {
timeout time.Duration
cache map[string]string
mutex sync.RWMutex
timeout time.Duration
cache map[string]string
mutex sync.RWMutex
noCache bool
}
// New creates a new resolver with the specified timeout
@@ -24,37 +25,44 @@ func New(timeout time.Duration) *Resolver {
return &Resolver{
timeout: timeout,
cache: make(map[string]string),
noCache: false,
}
}
// SetNoCache disables caching - each lookup will hit DNS directly
func (r *Resolver) SetNoCache(noCache bool) {
r.noCache = noCache
}
// ResolveAddr resolves an IP address to a hostname, with caching
func (r *Resolver) ResolveAddr(addr string) string {
// Check cache first
r.mutex.RLock()
if cached, exists := r.cache[addr]; exists {
// check cache first (unless caching is disabled)
if !r.noCache {
r.mutex.RLock()
if cached, exists := r.cache[addr]; exists {
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock()
// Parse IP to validate it
// parse ip to validate it
ip := net.ParseIP(addr)
if ip == nil {
// Not a valid IP, return as-is
return addr
}
// Perform resolution with timeout
// perform resolution with timeout
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
names, err := net.DefaultResolver.LookupAddr(ctx, addr)
resolved := addr // fallback to original address
resolved := addr
if err == nil && len(names) > 0 {
resolved = names[0]
// Remove trailing dot if present
// remove trailing dot if present
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
resolved = resolved[:len(resolved)-1]
}
@@ -65,10 +73,12 @@ func (r *Resolver) ResolveAddr(addr string) string {
fmt.Fprintf(os.Stderr, "[timing] slow DNS lookup: %s -> %s (%v)\n", addr, resolved, elapsed)
}
// Cache the result
r.mutex.Lock()
r.cache[addr] = resolved
r.mutex.Unlock()
// cache the result (unless caching is disabled)
if !r.noCache {
r.mutex.Lock()
r.cache[addr] = resolved
r.mutex.Unlock()
}
return resolved
}
@@ -81,15 +91,17 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
cacheKey := strconv.Itoa(port) + "/" + proto
// Check cache first
r.mutex.RLock()
if cached, exists := r.cache[cacheKey]; exists {
// check cache first (unless caching is disabled)
if !r.noCache {
r.mutex.RLock()
if cached, exists := r.cache[cacheKey]; exists {
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock()
return cached
}
r.mutex.RUnlock()
// Perform resolution with timeout
// perform resolution with timeout
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
@@ -97,16 +109,18 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
resolved := strconv.Itoa(port) // fallback to port number
if err == nil && service != 0 {
// Try to get service name
// try to get service name
if serviceName := getServiceName(port, proto); serviceName != "" {
resolved = serviceName
}
}
// Cache the result
r.mutex.Lock()
r.cache[cacheKey] = resolved
r.mutex.Unlock()
// cache the result (unless caching is disabled)
if !r.noCache {
r.mutex.Lock()
r.cache[cacheKey] = resolved
r.mutex.Unlock()
}
return resolved
}
@@ -169,22 +183,38 @@ func getServiceName(port int, proto string) string {
return ""
}
// Global resolver instance
// global resolver instance
var globalResolver *Resolver
// SetGlobalResolver sets the global resolver instance
func SetGlobalResolver(timeout time.Duration) {
// ResolverOptions configures the global resolver
type ResolverOptions struct {
Timeout time.Duration
NoCache bool
}
// SetGlobalResolver sets the global resolver instance with options
func SetGlobalResolver(opts ResolverOptions) {
timeout := opts.Timeout
if timeout == 0 {
timeout = 200 * time.Millisecond
}
globalResolver = New(timeout)
globalResolver.SetNoCache(opts.NoCache)
}
// GetGlobalResolver returns the global resolver instance
func GetGlobalResolver() *Resolver {
if globalResolver == nil {
globalResolver = New(200 * time.Millisecond) // Default timeout
globalResolver = New(200 * time.Millisecond)
}
return globalResolver
}
// SetNoCache configures whether the global resolver bypasses cache
func SetNoCache(noCache bool) {
GetGlobalResolver().SetNoCache(noCache)
}
// ResolveAddr is a convenience function using the global resolver
func ResolveAddr(addr string) string {
return GetGlobalResolver().ResolveAddr(addr)
@@ -199,3 +229,48 @@ func ResolvePort(port int, proto string) string {
func ResolveAddrPort(addr string, port int, proto string) (string, string) {
return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
}
// ResolveAddrsParallel resolves multiple addresses concurrently and caches results.
// This should be called before rendering to pre-warm the cache.
func (r *Resolver) ResolveAddrsParallel(addrs []string) {
// dedupe and filter addresses that need resolution
unique := make(map[string]struct{})
for _, addr := range addrs {
if addr == "" || addr == "*" {
continue
}
// skip if already cached
r.mutex.RLock()
_, exists := r.cache[addr]
r.mutex.RUnlock()
if exists {
continue
}
unique[addr] = struct{}{}
}
if len(unique) == 0 {
return
}
var wg sync.WaitGroup
// limit concurrency to avoid overwhelming dns
sem := make(chan struct{}, 32)
for addr := range unique {
wg.Add(1)
go func(a string) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
r.ResolveAddr(a)
}(addr)
}
wg.Wait()
}
// ResolveAddrsParallel is a convenience function using the global resolver
func ResolveAddrsParallel(addrs []string) {
GetGlobalResolver().ResolveAddrsParallel(addrs)
}