152 lines
3.3 KiB
Go
152 lines
3.3 KiB
Go
package dns
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// DNSConfig DNS配置文件结构
|
|
type DNSConfig struct {
|
|
Records map[string]string `json:"records"` // 普通记录和泛解析记录
|
|
Fallback bool `json:"fallback"` // 是否回退到系统DNS
|
|
TTL int `json:"ttl"` // 缓存TTL,单位为秒
|
|
}
|
|
|
|
// LoadFromJSON 从JSON文件加载DNS配置
|
|
func LoadFromJSON(filePath string) (*DNSConfig, error) {
|
|
file, err := os.Open(filePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("打开DNS配置文件失败: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
config := &DNSConfig{
|
|
Records: make(map[string]string),
|
|
Fallback: true,
|
|
TTL: 300, // 默认5分钟
|
|
}
|
|
|
|
decoder := json.NewDecoder(file)
|
|
if err := decoder.Decode(config); err != nil {
|
|
return nil, fmt.Errorf("解析DNS配置文件失败: %w", err)
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
// SaveToJSON 将DNS配置保存为JSON文件
|
|
func (c *DNSConfig) SaveToJSON(filePath string) error {
|
|
file, err := os.Create(filePath)
|
|
if err != nil {
|
|
return fmt.Errorf("创建DNS配置文件失败: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
encoder := json.NewEncoder(file)
|
|
encoder.SetIndent("", " ")
|
|
if err := encoder.Encode(c); err != nil {
|
|
return fmt.Errorf("保存DNS配置文件失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// 用于解析hosts文件中的IP:端口格式
|
|
var ipPortRegex = regexp.MustCompile(`^([0-9.]+)(?::(\d+))?$`)
|
|
|
|
// 检查是否为通配符域名
|
|
func isWildcardDomain(domain string) bool {
|
|
return strings.Contains(domain, "*")
|
|
}
|
|
|
|
// LoadFromHostsFile 从hosts文件格式加载DNS配置
|
|
func LoadFromHostsFile(filePath string) (*DNSConfig, error) {
|
|
file, err := os.Open(filePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("打开hosts文件失败: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
config := &DNSConfig{
|
|
Records: make(map[string]string),
|
|
Fallback: true,
|
|
TTL: 300, // 默认5分钟
|
|
}
|
|
|
|
scanner := bufio.NewScanner(file)
|
|
lineNum := 0
|
|
for scanner.Scan() {
|
|
lineNum++
|
|
line := strings.TrimSpace(scanner.Text())
|
|
|
|
// 跳过空行和注释
|
|
if line == "" || strings.HasPrefix(line, "#") {
|
|
continue
|
|
}
|
|
|
|
fields := strings.Fields(line)
|
|
if len(fields) < 2 {
|
|
continue // 行格式不正确,跳过
|
|
}
|
|
|
|
ipPortStr := fields[0]
|
|
domains := fields[1:]
|
|
|
|
// 解析IP和可能的端口
|
|
matches := ipPortRegex.FindStringSubmatch(ipPortStr)
|
|
if matches == nil {
|
|
continue // IP格式不正确,跳过
|
|
}
|
|
|
|
ip := matches[1]
|
|
portStr := matches[2]
|
|
|
|
// 构造记录值
|
|
value := ip
|
|
if portStr != "" {
|
|
value = ip + ":" + portStr
|
|
}
|
|
|
|
for _, domain := range domains {
|
|
// 跳过注释
|
|
if strings.HasPrefix(domain, "#") {
|
|
break
|
|
}
|
|
|
|
// 支持通配符和普通域名
|
|
config.Records[domain] = value
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, fmt.Errorf("读取hosts文件失败: %w", err)
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
// NewResolverFromConfig 从配置创建解析器
|
|
func NewResolverFromConfig(config *DNSConfig) *CustomResolver {
|
|
var ttl time.Duration
|
|
if config.TTL > 0 {
|
|
ttl = time.Duration(config.TTL) * time.Second
|
|
} else {
|
|
ttl = 5 * time.Minute // 默认5分钟
|
|
}
|
|
|
|
resolver := NewResolver(
|
|
WithFallback(config.Fallback),
|
|
WithTTL(ttl),
|
|
)
|
|
|
|
// 加载记录
|
|
resolver.LoadFromMap(config.Records)
|
|
|
|
return resolver
|
|
}
|