fix: dns resolution taking long and add caching options (#8)
This commit is contained in:
@@ -14,9 +14,10 @@ var debugTiming = os.Getenv("SNITCH_DEBUG_TIMING") != ""
|
||||
|
||||
// Resolver handles DNS and service name resolution with caching and timeouts
|
||||
type Resolver struct {
|
||||
timeout time.Duration
|
||||
cache map[string]string
|
||||
mutex sync.RWMutex
|
||||
timeout time.Duration
|
||||
cache map[string]string
|
||||
mutex sync.RWMutex
|
||||
noCache bool
|
||||
}
|
||||
|
||||
// New creates a new resolver with the specified timeout
|
||||
@@ -24,37 +25,44 @@ func New(timeout time.Duration) *Resolver {
|
||||
return &Resolver{
|
||||
timeout: timeout,
|
||||
cache: make(map[string]string),
|
||||
noCache: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetNoCache disables caching - each lookup will hit DNS directly
|
||||
func (r *Resolver) SetNoCache(noCache bool) {
|
||||
r.noCache = noCache
|
||||
}
|
||||
|
||||
// ResolveAddr resolves an IP address to a hostname, with caching
|
||||
func (r *Resolver) ResolveAddr(addr string) string {
|
||||
// Check cache first
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[addr]; exists {
|
||||
// check cache first (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[addr]; exists {
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
// Parse IP to validate it
|
||||
// parse ip to validate it
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
// Not a valid IP, return as-is
|
||||
return addr
|
||||
}
|
||||
|
||||
// Perform resolution with timeout
|
||||
// perform resolution with timeout
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
names, err := net.DefaultResolver.LookupAddr(ctx, addr)
|
||||
|
||||
resolved := addr // fallback to original address
|
||||
resolved := addr
|
||||
if err == nil && len(names) > 0 {
|
||||
resolved = names[0]
|
||||
// Remove trailing dot if present
|
||||
// remove trailing dot if present
|
||||
if len(resolved) > 0 && resolved[len(resolved)-1] == '.' {
|
||||
resolved = resolved[:len(resolved)-1]
|
||||
}
|
||||
@@ -65,10 +73,12 @@ func (r *Resolver) ResolveAddr(addr string) string {
|
||||
fmt.Fprintf(os.Stderr, "[timing] slow DNS lookup: %s -> %s (%v)\n", addr, resolved, elapsed)
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
r.mutex.Lock()
|
||||
r.cache[addr] = resolved
|
||||
r.mutex.Unlock()
|
||||
// cache the result (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.Lock()
|
||||
r.cache[addr] = resolved
|
||||
r.mutex.Unlock()
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
@@ -81,15 +91,17 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
|
||||
|
||||
cacheKey := strconv.Itoa(port) + "/" + proto
|
||||
|
||||
// Check cache first
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[cacheKey]; exists {
|
||||
// check cache first (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.RLock()
|
||||
if cached, exists := r.cache[cacheKey]; exists {
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
// Perform resolution with timeout
|
||||
// perform resolution with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -97,16 +109,18 @@ func (r *Resolver) ResolvePort(port int, proto string) string {
|
||||
|
||||
resolved := strconv.Itoa(port) // fallback to port number
|
||||
if err == nil && service != 0 {
|
||||
// Try to get service name
|
||||
// try to get service name
|
||||
if serviceName := getServiceName(port, proto); serviceName != "" {
|
||||
resolved = serviceName
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
r.mutex.Lock()
|
||||
r.cache[cacheKey] = resolved
|
||||
r.mutex.Unlock()
|
||||
// cache the result (unless caching is disabled)
|
||||
if !r.noCache {
|
||||
r.mutex.Lock()
|
||||
r.cache[cacheKey] = resolved
|
||||
r.mutex.Unlock()
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
@@ -169,22 +183,38 @@ func getServiceName(port int, proto string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Global resolver instance
|
||||
// global resolver instance
|
||||
var globalResolver *Resolver
|
||||
|
||||
// SetGlobalResolver sets the global resolver instance
|
||||
func SetGlobalResolver(timeout time.Duration) {
|
||||
// ResolverOptions configures the global resolver
|
||||
type ResolverOptions struct {
|
||||
Timeout time.Duration
|
||||
NoCache bool
|
||||
}
|
||||
|
||||
// SetGlobalResolver sets the global resolver instance with options
|
||||
func SetGlobalResolver(opts ResolverOptions) {
|
||||
timeout := opts.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 200 * time.Millisecond
|
||||
}
|
||||
globalResolver = New(timeout)
|
||||
globalResolver.SetNoCache(opts.NoCache)
|
||||
}
|
||||
|
||||
// GetGlobalResolver returns the global resolver instance
|
||||
func GetGlobalResolver() *Resolver {
|
||||
if globalResolver == nil {
|
||||
globalResolver = New(200 * time.Millisecond) // Default timeout
|
||||
globalResolver = New(200 * time.Millisecond)
|
||||
}
|
||||
return globalResolver
|
||||
}
|
||||
|
||||
// SetNoCache configures whether the global resolver bypasses cache
|
||||
func SetNoCache(noCache bool) {
|
||||
GetGlobalResolver().SetNoCache(noCache)
|
||||
}
|
||||
|
||||
// ResolveAddr is a convenience function using the global resolver
|
||||
func ResolveAddr(addr string) string {
|
||||
return GetGlobalResolver().ResolveAddr(addr)
|
||||
@@ -199,3 +229,48 @@ func ResolvePort(port int, proto string) string {
|
||||
func ResolveAddrPort(addr string, port int, proto string) (string, string) {
|
||||
return GetGlobalResolver().ResolveAddrPort(addr, port, proto)
|
||||
}
|
||||
|
||||
// ResolveAddrsParallel resolves multiple addresses concurrently and caches results.
|
||||
// This should be called before rendering to pre-warm the cache.
|
||||
func (r *Resolver) ResolveAddrsParallel(addrs []string) {
|
||||
// dedupe and filter addresses that need resolution
|
||||
unique := make(map[string]struct{})
|
||||
for _, addr := range addrs {
|
||||
if addr == "" || addr == "*" {
|
||||
continue
|
||||
}
|
||||
// skip if already cached
|
||||
r.mutex.RLock()
|
||||
_, exists := r.cache[addr]
|
||||
r.mutex.RUnlock()
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
unique[addr] = struct{}{}
|
||||
}
|
||||
|
||||
if len(unique) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
// limit concurrency to avoid overwhelming dns
|
||||
sem := make(chan struct{}, 32)
|
||||
|
||||
for addr := range unique {
|
||||
wg.Add(1)
|
||||
go func(a string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
r.ResolveAddr(a)
|
||||
}(addr)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// ResolveAddrsParallel is a convenience function using the global resolver
|
||||
func ResolveAddrsParallel(addrs []string) {
|
||||
GetGlobalResolver().ResolveAddrsParallel(addrs)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user