Files
goproxy/internal/dns/config.go
T
2025-03-13 17:53:08 +08:00

152 lines
3.3 KiB
Go

package dns
import (
"bufio"
"encoding/json"
"fmt"
"os"
"regexp"
"strings"
"time"
)
// DNSConfig DNS配置文件结构
type DNSConfig struct {
Records map[string]string `json:"records"` // 普通记录和泛解析记录
Fallback bool `json:"fallback"` // 是否回退到系统DNS
TTL int `json:"ttl"` // 缓存TTL,单位为秒
}
// LoadFromJSON 从JSON文件加载DNS配置
func LoadFromJSON(filePath string) (*DNSConfig, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("打开DNS配置文件失败: %w", err)
}
defer file.Close()
config := &DNSConfig{
Records: make(map[string]string),
Fallback: true,
TTL: 300, // 默认5分钟
}
decoder := json.NewDecoder(file)
if err := decoder.Decode(config); err != nil {
return nil, fmt.Errorf("解析DNS配置文件失败: %w", err)
}
return config, nil
}
// SaveToJSON 将DNS配置保存为JSON文件
func (c *DNSConfig) SaveToJSON(filePath string) error {
file, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("创建DNS配置文件失败: %w", err)
}
defer file.Close()
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(c); err != nil {
return fmt.Errorf("保存DNS配置文件失败: %w", err)
}
return nil
}
// 用于解析hosts文件中的IP:端口格式
var ipPortRegex = regexp.MustCompile(`^([0-9.]+)(?::(\d+))?$`)
// 检查是否为通配符域名
func isWildcardDomain(domain string) bool {
return strings.Contains(domain, "*")
}
// LoadFromHostsFile 从hosts文件格式加载DNS配置
func LoadFromHostsFile(filePath string) (*DNSConfig, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("打开hosts文件失败: %w", err)
}
defer file.Close()
config := &DNSConfig{
Records: make(map[string]string),
Fallback: true,
TTL: 300, // 默认5分钟
}
scanner := bufio.NewScanner(file)
lineNum := 0
for scanner.Scan() {
lineNum++
line := strings.TrimSpace(scanner.Text())
// 跳过空行和注释
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
if len(fields) < 2 {
continue // 行格式不正确,跳过
}
ipPortStr := fields[0]
domains := fields[1:]
// 解析IP和可能的端口
matches := ipPortRegex.FindStringSubmatch(ipPortStr)
if matches == nil {
continue // IP格式不正确,跳过
}
ip := matches[1]
portStr := matches[2]
// 构造记录值
value := ip
if portStr != "" {
value = ip + ":" + portStr
}
for _, domain := range domains {
// 跳过注释
if strings.HasPrefix(domain, "#") {
break
}
// 支持通配符和普通域名
config.Records[domain] = value
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("读取hosts文件失败: %w", err)
}
return config, nil
}
// NewResolverFromConfig 从配置创建解析器
func NewResolverFromConfig(config *DNSConfig) *CustomResolver {
var ttl time.Duration
if config.TTL > 0 {
ttl = time.Duration(config.TTL) * time.Second
} else {
ttl = 5 * time.Minute // 默认5分钟
}
resolver := NewResolver(
WithFallback(config.Fallback),
WithTTL(ttl),
)
// 加载记录
resolver.LoadFromMap(config.Records)
return resolver
}