388 lines
8.2 KiB
Go
388 lines
8.2 KiB
Go
package resolver
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestNew(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
if r == nil {
|
|
t.Fatal("expected non-nil resolver")
|
|
}
|
|
if r.timeout != 100*time.Millisecond {
|
|
t.Errorf("expected timeout 100ms, got %v", r.timeout)
|
|
}
|
|
if r.cache == nil {
|
|
t.Error("expected cache to be initialized")
|
|
}
|
|
if r.noCache {
|
|
t.Error("expected noCache to be false by default")
|
|
}
|
|
}
|
|
|
|
func TestSetNoCache(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
r.SetNoCache(true)
|
|
if !r.noCache {
|
|
t.Error("expected noCache to be true")
|
|
}
|
|
|
|
r.SetNoCache(false)
|
|
if r.noCache {
|
|
t.Error("expected noCache to be false")
|
|
}
|
|
}
|
|
|
|
func TestResolveAddr_InvalidIP(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
// invalid ip should return as-is
|
|
result := r.ResolveAddr("not-an-ip")
|
|
if result != "not-an-ip" {
|
|
t.Errorf("expected 'not-an-ip', got %q", result)
|
|
}
|
|
|
|
// empty string should return as-is
|
|
result = r.ResolveAddr("")
|
|
if result != "" {
|
|
t.Errorf("expected empty string, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestResolveAddr_Caching(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
// first call should cache
|
|
addr := "127.0.0.1"
|
|
result1 := r.ResolveAddr(addr)
|
|
|
|
// verify cache is populated
|
|
if r.GetCacheSize() != 1 {
|
|
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
|
|
}
|
|
|
|
// second call should use cache
|
|
result2 := r.ResolveAddr(addr)
|
|
if result1 != result2 {
|
|
t.Errorf("expected same result from cache, got %q and %q", result1, result2)
|
|
}
|
|
}
|
|
|
|
func TestResolveAddr_NoCacheMode(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
r.SetNoCache(true)
|
|
|
|
addr := "127.0.0.1"
|
|
r.ResolveAddr(addr)
|
|
|
|
// cache should remain empty when noCache is enabled
|
|
if r.GetCacheSize() != 0 {
|
|
t.Errorf("expected cache size 0 with noCache, got %d", r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestResolvePort_Zero(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
result := r.ResolvePort(0, "tcp")
|
|
if result != "0" {
|
|
t.Errorf("expected '0' for port 0, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestResolvePort_WellKnown(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
tests := []struct {
|
|
port int
|
|
proto string
|
|
expected string
|
|
}{
|
|
{80, "tcp", "http"},
|
|
{443, "tcp", "https"},
|
|
{22, "tcp", "ssh"},
|
|
{53, "udp", "domain"},
|
|
{5432, "tcp", "postgresql"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := r.ResolvePort(tt.port, tt.proto)
|
|
if result != tt.expected {
|
|
t.Errorf("ResolvePort(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestResolvePort_Caching(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
r.ResolvePort(80, "tcp")
|
|
r.ResolvePort(443, "tcp")
|
|
|
|
if r.GetCacheSize() != 2 {
|
|
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
|
|
}
|
|
|
|
// same port/proto should not add new entry
|
|
r.ResolvePort(80, "tcp")
|
|
if r.GetCacheSize() != 2 {
|
|
t.Errorf("expected cache size still 2, got %d", r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestResolveAddrPort(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
addr, port := r.ResolveAddrPort("127.0.0.1", 80, "tcp")
|
|
|
|
if addr == "" {
|
|
t.Error("expected non-empty address")
|
|
}
|
|
if port != "http" {
|
|
t.Errorf("expected port 'http', got %q", port)
|
|
}
|
|
}
|
|
|
|
func TestClearCache(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
r.ResolveAddr("127.0.0.1")
|
|
r.ResolvePort(80, "tcp")
|
|
|
|
if r.GetCacheSize() == 0 {
|
|
t.Error("expected non-empty cache before clear")
|
|
}
|
|
|
|
r.ClearCache()
|
|
|
|
if r.GetCacheSize() != 0 {
|
|
t.Errorf("expected empty cache after clear, got %d", r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestGetCacheSize(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
if r.GetCacheSize() != 0 {
|
|
t.Errorf("expected initial cache size 0, got %d", r.GetCacheSize())
|
|
}
|
|
|
|
r.ResolveAddr("127.0.0.1")
|
|
if r.GetCacheSize() != 1 {
|
|
t.Errorf("expected cache size 1, got %d", r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestGetServiceName(t *testing.T) {
|
|
tests := []struct {
|
|
port int
|
|
proto string
|
|
expected string
|
|
}{
|
|
{80, "tcp", "http"},
|
|
{443, "tcp", "https"},
|
|
{22, "tcp", "ssh"},
|
|
{53, "tcp", "domain"},
|
|
{53, "udp", "domain"},
|
|
{12345, "tcp", ""},
|
|
{0, "tcp", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := getServiceName(tt.port, tt.proto)
|
|
if result != tt.expected {
|
|
t.Errorf("getServiceName(%d, %q) = %q, want %q", tt.port, tt.proto, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestResolveAddrsParallel(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
addrs := []string{
|
|
"127.0.0.1",
|
|
"127.0.0.2",
|
|
"127.0.0.3",
|
|
"", // should be skipped
|
|
"*", // should be skipped
|
|
}
|
|
|
|
r.ResolveAddrsParallel(addrs)
|
|
|
|
// should have cached 3 addresses (excluding empty and *)
|
|
if r.GetCacheSize() != 3 {
|
|
t.Errorf("expected cache size 3, got %d", r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestResolveAddrsParallel_Dedupe(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
addrs := []string{
|
|
"127.0.0.1",
|
|
"127.0.0.1",
|
|
"127.0.0.1",
|
|
"127.0.0.2",
|
|
}
|
|
|
|
r.ResolveAddrsParallel(addrs)
|
|
|
|
// should have cached 2 unique addresses
|
|
if r.GetCacheSize() != 2 {
|
|
t.Errorf("expected cache size 2, got %d", r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestResolveAddrsParallel_SkipsCached(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
// pre-cache one address
|
|
r.ResolveAddr("127.0.0.1")
|
|
|
|
addrs := []string{
|
|
"127.0.0.1", // already cached
|
|
"127.0.0.2", // not cached
|
|
}
|
|
|
|
initialSize := r.GetCacheSize()
|
|
r.ResolveAddrsParallel(addrs)
|
|
|
|
// should have added 1 more
|
|
if r.GetCacheSize() != initialSize+1 {
|
|
t.Errorf("expected cache size %d, got %d", initialSize+1, r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestResolveAddrsParallel_Empty(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
// should not panic with empty input
|
|
r.ResolveAddrsParallel([]string{})
|
|
r.ResolveAddrsParallel(nil)
|
|
|
|
if r.GetCacheSize() != 0 {
|
|
t.Errorf("expected cache size 0, got %d", r.GetCacheSize())
|
|
}
|
|
}
|
|
|
|
func TestGlobalResolver(t *testing.T) {
|
|
// reset global resolver
|
|
globalResolver = nil
|
|
|
|
r := GetGlobalResolver()
|
|
if r == nil {
|
|
t.Fatal("expected non-nil global resolver")
|
|
}
|
|
|
|
// should return same instance
|
|
r2 := GetGlobalResolver()
|
|
if r != r2 {
|
|
t.Error("expected same global resolver instance")
|
|
}
|
|
}
|
|
|
|
func TestSetGlobalResolver(t *testing.T) {
|
|
SetGlobalResolver(ResolverOptions{
|
|
Timeout: 500 * time.Millisecond,
|
|
NoCache: true,
|
|
})
|
|
|
|
r := GetGlobalResolver()
|
|
if r.timeout != 500*time.Millisecond {
|
|
t.Errorf("expected timeout 500ms, got %v", r.timeout)
|
|
}
|
|
if !r.noCache {
|
|
t.Error("expected noCache to be true")
|
|
}
|
|
|
|
// reset for other tests
|
|
globalResolver = nil
|
|
}
|
|
|
|
func TestSetGlobalResolver_DefaultTimeout(t *testing.T) {
|
|
SetGlobalResolver(ResolverOptions{
|
|
Timeout: 0, // should use default
|
|
})
|
|
|
|
r := GetGlobalResolver()
|
|
if r.timeout != 200*time.Millisecond {
|
|
t.Errorf("expected default timeout 200ms, got %v", r.timeout)
|
|
}
|
|
|
|
// reset for other tests
|
|
globalResolver = nil
|
|
}
|
|
|
|
func TestGlobalConvenienceFunctions(t *testing.T) {
|
|
globalResolver = nil
|
|
|
|
// test global ResolveAddr
|
|
result := ResolveAddr("127.0.0.1")
|
|
if result == "" {
|
|
t.Error("expected non-empty result from global ResolveAddr")
|
|
}
|
|
|
|
// test global ResolvePort
|
|
port := ResolvePort(80, "tcp")
|
|
if port != "http" {
|
|
t.Errorf("expected 'http', got %q", port)
|
|
}
|
|
|
|
// test global ResolveAddrPort
|
|
addr, portStr := ResolveAddrPort("127.0.0.1", 443, "tcp")
|
|
if addr == "" {
|
|
t.Error("expected non-empty address")
|
|
}
|
|
if portStr != "https" {
|
|
t.Errorf("expected 'https', got %q", portStr)
|
|
}
|
|
|
|
// test global SetNoCache
|
|
SetNoCache(true)
|
|
if !GetGlobalResolver().noCache {
|
|
t.Error("expected global noCache to be true")
|
|
}
|
|
|
|
// reset
|
|
globalResolver = nil
|
|
}
|
|
|
|
func TestConcurrentAccess(t *testing.T) {
|
|
r := New(100 * time.Millisecond)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func(n int) {
|
|
defer wg.Done()
|
|
addr := "127.0.0.1"
|
|
r.ResolveAddr(addr)
|
|
r.ResolvePort(80+n%10, "tcp")
|
|
r.GetCacheSize()
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// should not panic and cache should have entries
|
|
if r.GetCacheSize() == 0 {
|
|
t.Error("expected non-empty cache after concurrent access")
|
|
}
|
|
}
|
|
|
|
func TestResolveAddr_TrailingDot(t *testing.T) {
|
|
// this test verifies the trailing dot removal logic
|
|
// by checking the internal logic works correctly
|
|
r := New(100 * time.Millisecond)
|
|
|
|
// localhost should resolve and have trailing dot removed
|
|
result := r.ResolveAddr("127.0.0.1")
|
|
if len(result) > 0 && result[len(result)-1] == '.' {
|
|
t.Error("expected trailing dot to be removed")
|
|
}
|
|
}
|
|
|