mirror of
https://github.com/click33/sa-token-go.git
synced 2026-04-22 21:17:04 +08:00
refactor: remove FINAL_REPORT.md and enhance error handling, configuration, and session management in core components
This commit is contained in:
-127
@@ -1,127 +0,0 @@
|
||||
# 🎉 Sa-Token-Go 项目完成报告
|
||||
|
||||
## ✅ 项目信息
|
||||
|
||||
**项目名称**: Sa-Token-Go
|
||||
**版本**: v0.1.0
|
||||
**作者**: click33
|
||||
**仓库**: https://github.com/click33/sa-token-go
|
||||
**完成日期**: 2025-10-13
|
||||
|
||||
---
|
||||
|
||||
## 🚀 核心功能
|
||||
|
||||
### 1. 超简洁API
|
||||
```go
|
||||
// 一行初始化
|
||||
stputil.SetManager(core.NewBuilder().Storage(memory.NewStorage()).Build())
|
||||
|
||||
// 直接使用
|
||||
stputil.Login(1000)
|
||||
```
|
||||
|
||||
### 2. 注解装饰器
|
||||
```go
|
||||
r.GET("/public", sagin.Ignore(), handler)
|
||||
r.GET("/user", sagin.CheckLogin(), handler)
|
||||
r.GET("/admin", sagin.CheckPermission("admin"), handler)
|
||||
```
|
||||
|
||||
### 3. 异步续签
|
||||
- 性能提升 400%
|
||||
- 响应延迟从 250ms → 50ms
|
||||
- QPS从 2000 → 10000
|
||||
|
||||
### 4. 完整功能
|
||||
40+核心方法,涵盖所有认证授权场景
|
||||
|
||||
---
|
||||
|
||||
## 📂 项目结构
|
||||
|
||||
```
|
||||
sa-token-go/
|
||||
├── core/ # 核心模块
|
||||
│ ├── manager/ # 认证管理器(异步续签)
|
||||
│ ├── builder/ # Builder构建器
|
||||
│ ├── stputil/ # 全局工具类
|
||||
│ └── ...
|
||||
├── storage/
|
||||
│ ├── memory/ # 内存存储
|
||||
│ └── redis/ # Redis存储
|
||||
├── integrations/
|
||||
│ ├── gin/ # Gin集成(含注解)
|
||||
│ ├── echo/ # Echo集成
|
||||
│ ├── fiber/ # Fiber集成
|
||||
│ └── chi/ # Chi集成
|
||||
├── examples/
|
||||
│ ├── quick-start/ # 快速开始
|
||||
│ ├── annotation/ # 注解使用
|
||||
│ └── gin/echo/fiber/chi # 框架集成
|
||||
└── docs/
|
||||
├── tutorial/ # 教程
|
||||
├── guide/ # 使用指南
|
||||
├── api/ # API文档
|
||||
└── design/ # 设计文档
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 项目统计
|
||||
|
||||
| 项目 | 数量 |
|
||||
|------|------|
|
||||
| Go源文件 | 31个 |
|
||||
| 文档文件 | 10个 |
|
||||
| 模块数量 | 13个 |
|
||||
| 核心方法 | 40+ |
|
||||
| 装饰器 | 5个 |
|
||||
| 事件类型 | 8种 |
|
||||
|
||||
---
|
||||
|
||||
## 📚 文档体系
|
||||
|
||||
### 主文档
|
||||
- README.md - 英文
|
||||
- README_zh.md - 中文
|
||||
|
||||
### 详细文档
|
||||
- docs/tutorial/ - 教程
|
||||
- docs/guide/ - 使用指南
|
||||
- docs/api/ - API文档
|
||||
- docs/design/ - 设计文档
|
||||
|
||||
---
|
||||
|
||||
## 🎯 核心优势
|
||||
|
||||
1. **超简洁** - 一行初始化
|
||||
2. **全局工具类** - 无需传递manager
|
||||
3. **装饰器模式** - 类似Java注解
|
||||
4. **异步续签** - 性能提升400%
|
||||
5. **模块化** - 按需导入
|
||||
6. **类型友好** - 支持多种类型
|
||||
|
||||
---
|
||||
|
||||
## 🚀 推送到GitHub
|
||||
|
||||
```bash
|
||||
cd /Users/m1pro/go_project/sa-token-go
|
||||
git init
|
||||
git add .
|
||||
git commit -m "feat: Sa-Token-Go v0.1.0
|
||||
|
||||
- 超简洁API:Builder+StpUtil
|
||||
- 注解装饰器:@SaCheckLogin等
|
||||
- 异步续签:性能提升400%
|
||||
- 完整文档:tutorial/guide/api/design"
|
||||
git remote add origin https://github.com/click33/sa-token-go.git
|
||||
git push -u origin main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**Sa-Token-Go v0.1.0 - 完成!** 🎉
|
||||
+69
-8
@@ -1,34 +1,95 @@
|
||||
package adapter
|
||||
|
||||
// CookieOptions Cookie setting options | Cookie设置选项
|
||||
type CookieOptions struct {
|
||||
// Name Cookie name | Cookie名称
|
||||
Name string
|
||||
// Value Cookie value | Cookie值
|
||||
Value string
|
||||
// MaxAge Cookie expiration time in seconds, 0 means delete cookie, -1 means session cookie | 过期时间(秒),0表示删除cookie,-1表示会话cookie
|
||||
MaxAge int
|
||||
// Path Cookie path | 路径
|
||||
Path string
|
||||
// Domain Cookie domain | 域名
|
||||
Domain string
|
||||
// Secure Only effective under HTTPS | 是否只在HTTPS下生效
|
||||
Secure bool
|
||||
// HttpOnly Prevent JavaScript access | 是否禁止JS访问
|
||||
HttpOnly bool
|
||||
// SameSite SameSite attribute (Strict, Lax, None) | SameSite属性
|
||||
SameSite string
|
||||
}
|
||||
|
||||
// RequestContext defines request context interface for abstracting different web frameworks | 定义请求上下文接口,用于抽象不同Web框架的请求/响应
|
||||
type RequestContext interface {
|
||||
// ============== Request Methods | 请求方法 ==============
|
||||
|
||||
// GetHeader gets request header | 获取请求头
|
||||
GetHeader(key string) string
|
||||
|
||||
// GetHeaders gets all request headers | 获取所有请求头
|
||||
GetHeaders() map[string][]string
|
||||
|
||||
// GetQuery gets query parameter | 获取查询参数
|
||||
GetQuery(key string) string
|
||||
|
||||
// GetQueryAll gets all query parameters | 获取所有查询参数
|
||||
GetQueryAll() map[string][]string
|
||||
|
||||
// GetPostForm gets POST form parameter | 获取POST表单参数
|
||||
GetPostForm(key string) string
|
||||
|
||||
// GetCookie gets cookie | 获取Cookie
|
||||
GetCookie(key string) string
|
||||
|
||||
// SetHeader sets response header | 设置响应头
|
||||
SetHeader(key, value string)
|
||||
|
||||
// SetCookie sets cookie | 设置Cookie
|
||||
SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool)
|
||||
// GetBody gets request body as bytes | 获取请求体字节数据
|
||||
GetBody() ([]byte, error)
|
||||
|
||||
// GetClientIP gets client IP address | 获取客户端IP地址
|
||||
GetClientIP() string
|
||||
|
||||
// GetMethod gets request method | 获取请求方法
|
||||
// GetMethod gets request method (GET, POST, etc.) | 获取请求方法(GET、POST等)
|
||||
GetMethod() string
|
||||
|
||||
// GetPath gets request path | 获取请求路径
|
||||
GetPath() string
|
||||
|
||||
// GetURL gets full request URL | 获取完整请求URL
|
||||
GetURL() string
|
||||
|
||||
// GetUserAgent gets User-Agent header | 获取User-Agent
|
||||
GetUserAgent() string
|
||||
|
||||
// ============== Response Methods | 响应方法 ==============
|
||||
|
||||
// SetHeader sets response header | 设置响应头
|
||||
SetHeader(key, value string)
|
||||
|
||||
// SetCookie sets cookie (legacy method for backward compatibility) | 设置Cookie(兼容旧版本的方法)
|
||||
SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool)
|
||||
|
||||
// SetCookieWithOptions sets cookie with options | 使用选项设置Cookie
|
||||
SetCookieWithOptions(options *CookieOptions)
|
||||
|
||||
// ============== Context Storage Methods | 上下文存储方法 ==============
|
||||
|
||||
// Set sets context value | 设置上下文值
|
||||
Set(key string, value interface{})
|
||||
Set(key string, value any)
|
||||
|
||||
// Get gets context value | 获取上下文值
|
||||
Get(key string) (interface{}, bool)
|
||||
Get(key string) (any, bool)
|
||||
|
||||
// GetString gets string value from context | 从上下文获取字符串值
|
||||
GetString(key string) string
|
||||
|
||||
// MustGet gets context value, panics if not exists | 获取上下文值,不存在则panic
|
||||
MustGet(key string) any
|
||||
|
||||
// ============== Utility Methods | 工具方法 ==============
|
||||
|
||||
// Abort aborts the request processing | 中止请求处理
|
||||
Abort()
|
||||
|
||||
// IsAborted checks if the request is aborted | 检查请求是否已中止
|
||||
IsAborted() bool
|
||||
}
|
||||
|
||||
+17
-8
@@ -4,27 +4,36 @@ import "time"
|
||||
|
||||
// Storage defines storage interface for Token and Session data | 定义存储接口,用于存储Token和Session数据
|
||||
type Storage interface {
|
||||
// ============== Basic Operations | 基本操作 ==============
|
||||
|
||||
// Set sets key-value pair with optional expiration time (0 means never expire) | 设置键值对,可选过期时间(0表示永不过期)
|
||||
Set(key string, value interface{}, expiration time.Duration) error
|
||||
Set(key string, value any, expiration time.Duration) error
|
||||
|
||||
// Get gets value by key | 获取键对应的值
|
||||
Get(key string) (interface{}, error)
|
||||
// Get gets value by key, returns nil if key doesn't exist | 获取键对应的值,键不存在时返回nil
|
||||
Get(key string) (any, error)
|
||||
|
||||
// Delete deletes key | 删除键
|
||||
Delete(key string) error
|
||||
// Delete deletes one or more keys | 删除一个或多个键
|
||||
Delete(keys ...string) error
|
||||
|
||||
// Exists checks if key exists | 检查键是否存在
|
||||
Exists(key string) bool
|
||||
|
||||
// Keys gets all keys matching pattern | 获取匹配模式的所有键
|
||||
// ============== Key Management | 键管理 ==============
|
||||
|
||||
// Keys gets all keys matching pattern (e.g., "user:*") | 获取匹配模式的所有键(如:"user:*")
|
||||
Keys(pattern string) ([]string, error)
|
||||
|
||||
// Expire sets expiration time for key | 设置键的过期时间
|
||||
Expire(key string, expiration time.Duration) error
|
||||
|
||||
// TTL gets remaining time to live | 获取键的剩余生存时间
|
||||
// TTL gets remaining time to live (-1 if no expiration, -2 if key doesn't exist) | 获取键的剩余生存时间(-1表示永不过期,-2表示键不存在)
|
||||
TTL(key string) (time.Duration, error)
|
||||
|
||||
// Clear clears all data (for testing) | 清空所有数据(用于测试)
|
||||
// ============== Utility Methods | 工具方法 ==============
|
||||
|
||||
// Clear clears all data (use with caution, mainly for testing) | 清空所有数据(谨慎使用,主要用于测试)
|
||||
Clear() error
|
||||
|
||||
// Ping checks if storage is accessible | 检查存储是否可访问
|
||||
Ping() error
|
||||
}
|
||||
|
||||
+49
-39
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// Version version number | 版本号
|
||||
const Version = "0.1.0"
|
||||
const Version = "0.1.1"
|
||||
|
||||
// Banner startup banner | 启动横幅
|
||||
const Banner = `
|
||||
@@ -18,9 +18,18 @@ const Banner = `
|
||||
___/ / /_/ / / / / /_/ / ,< / __/ / / /_____/ /_/ / /_/ /
|
||||
/____/\__,_/ /_/ \____/_/|_|\___/_/ /_/ \____/\____/
|
||||
|
||||
:: Sa-Token-Go :: (v%s)
|
||||
:: Sa-Token-Go :: v%s
|
||||
`
|
||||
|
||||
const (
|
||||
boxWidth = 57
|
||||
labelWidth = 16
|
||||
neverExpire = "Never Expire"
|
||||
noLimit = "No Limit"
|
||||
configured = "*** (configured)"
|
||||
secondsFormat = "%d seconds"
|
||||
)
|
||||
|
||||
// Print prints startup banner | 打印启动横幅
|
||||
func Print() {
|
||||
fmt.Printf(Banner, Version)
|
||||
@@ -29,6 +38,36 @@ func Print() {
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// formatConfigLine formats a configuration line with proper padding | 格式化配置行
|
||||
func formatConfigLine(label string, value any) string {
|
||||
valueWidth := boxWidth - labelWidth - 5 // 57 - 16 - 5 = 36
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
return fmt.Sprintf("│ %-*s: %-*s │\n", labelWidth, label, valueWidth, valueStr)
|
||||
}
|
||||
|
||||
// formatTimeout formats timeout value (seconds or special text) | 格式化超时时间值
|
||||
func formatTimeout(seconds int64) string {
|
||||
if seconds > 0 {
|
||||
// Also show human-readable format for large values
|
||||
if seconds >= 86400 {
|
||||
days := seconds / 86400
|
||||
return fmt.Sprintf("%d seconds (%d days)", seconds, days)
|
||||
}
|
||||
return fmt.Sprintf(secondsFormat, seconds)
|
||||
} else if seconds == 0 {
|
||||
return neverExpire
|
||||
}
|
||||
return noLimit
|
||||
}
|
||||
|
||||
// formatCount formats count value (number or "No Limit") | 格式化数量值
|
||||
func formatCount(count int) string {
|
||||
if count > 0 {
|
||||
return fmt.Sprintf("%d", count)
|
||||
}
|
||||
return noLimit
|
||||
}
|
||||
|
||||
// PrintWithConfig prints startup banner with full configuration | 打印启动横幅和完整配置信息
|
||||
func PrintWithConfig(cfg *config.Config) {
|
||||
Print()
|
||||
@@ -38,46 +77,17 @@ func PrintWithConfig(cfg *config.Config) {
|
||||
fmt.Println("├─────────────────────────────────────────────────────────┤")
|
||||
|
||||
// Token configuration | Token 配置
|
||||
fmt.Printf("│ Token Name : %-35s │\n", cfg.TokenName)
|
||||
fmt.Printf("│ Token Style : %-35s │\n", cfg.TokenStyle)
|
||||
|
||||
if cfg.Timeout > 0 {
|
||||
fmt.Printf("│ Token Timeout : %-25d seconds │\n", cfg.Timeout)
|
||||
} else {
|
||||
fmt.Printf("│ Token Timeout : %-35s │\n", "Never Expire")
|
||||
}
|
||||
|
||||
if cfg.ActiveTimeout > 0 {
|
||||
fmt.Printf("│ Active Timeout : %-25d seconds │\n", cfg.ActiveTimeout)
|
||||
} else {
|
||||
fmt.Printf("│ Active Timeout : %-35s │\n", "No Limit")
|
||||
}
|
||||
fmt.Print(formatConfigLine("Token Name", cfg.TokenName))
|
||||
fmt.Print(formatConfigLine("Token Style", cfg.TokenStyle))
|
||||
fmt.Print(formatConfigLine("Token Timeout", formatTimeout(cfg.Timeout)))
|
||||
fmt.Print(formatConfigLine("Active Timeout", formatTimeout(cfg.ActiveTimeout)))
|
||||
|
||||
// Login configuration | 登录配置
|
||||
fmt.Printf("│ Auto Renew : %-35v │\n", cfg.AutoRenew)
|
||||
fmt.Printf("│ Concurrent : %-35v │\n", cfg.IsConcurrent)
|
||||
fmt.Printf("│ Share Token : %-35v │\n", cfg.IsShare)
|
||||
|
||||
if cfg.MaxLoginCount > 0 {
|
||||
fmt.Printf("│ Max Login Count : %-35d │\n", cfg.MaxLoginCount)
|
||||
} else {
|
||||
fmt.Printf("│ Max Login Count : %-35s │\n", "No Limit")
|
||||
}
|
||||
|
||||
// Token read source | Token 读取位置
|
||||
fmt.Println("├─────────────────────────────────────────────────────────┤")
|
||||
fmt.Printf("│ Read From Header: %-35v │\n", cfg.IsReadHeader)
|
||||
fmt.Printf("│ Read From Cookie: %-35v │\n", cfg.IsReadCookie)
|
||||
fmt.Printf("│ Read From Body : %-35v │\n", cfg.IsReadBody)
|
||||
|
||||
// Other settings | 其他设置
|
||||
fmt.Println("├─────────────────────────────────────────────────────────┤")
|
||||
|
||||
if cfg.TokenStyle == config.TokenStyleJWT && cfg.JwtSecretKey != "" {
|
||||
fmt.Printf("│ JWT Secret : %-35s │\n", "*** (configured)")
|
||||
}
|
||||
|
||||
fmt.Printf("│ Logging : %-35v │\n", cfg.IsLog)
|
||||
fmt.Print(formatConfigLine("Auto Renew", cfg.AutoRenew))
|
||||
fmt.Print(formatConfigLine("Concurrent", cfg.IsConcurrent))
|
||||
fmt.Print(formatConfigLine("Share Token", cfg.IsShare))
|
||||
fmt.Print(formatConfigLine("Max Login Count", formatCount(cfg.MaxLoginCount)))
|
||||
|
||||
fmt.Println("└─────────────────────────────────────────────────────────┘")
|
||||
fmt.Println()
|
||||
|
||||
@@ -0,0 +1,371 @@
|
||||
package banner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/click33/sa-token-go/core/config"
|
||||
)
|
||||
|
||||
// captureOutput captures stdout output for testing
|
||||
func captureOutput(f func()) string {
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
f()
|
||||
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestPrint(t *testing.T) {
|
||||
output := captureOutput(func() {
|
||||
Print()
|
||||
})
|
||||
|
||||
// Check if output contains expected elements
|
||||
if !strings.Contains(output, "Sa-Token-Go") {
|
||||
t.Error("Output should contain 'Sa-Token-Go'")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, Version) {
|
||||
t.Errorf("Output should contain version %s", Version)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Go Version") {
|
||||
t.Error("Output should contain 'Go Version'")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, runtime.Version()) {
|
||||
t.Errorf("Output should contain Go version %s", runtime.Version())
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "GOOS/GOARCH") {
|
||||
t.Error("Output should contain 'GOOS/GOARCH'")
|
||||
}
|
||||
|
||||
expectedOS := runtime.GOOS + "/" + runtime.GOARCH
|
||||
if !strings.Contains(output, expectedOS) {
|
||||
t.Errorf("Output should contain OS/ARCH %s", expectedOS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
seconds int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Positive seconds less than a day",
|
||||
seconds: 3600,
|
||||
expected: "3600 seconds",
|
||||
},
|
||||
{
|
||||
name: "Exactly one day",
|
||||
seconds: 86400,
|
||||
expected: "86400 seconds (1 days)",
|
||||
},
|
||||
{
|
||||
name: "Multiple days",
|
||||
seconds: 259200, // 3 days
|
||||
expected: "259200 seconds (3 days)",
|
||||
},
|
||||
{
|
||||
name: "30 days",
|
||||
seconds: 2592000,
|
||||
expected: "2592000 seconds (30 days)",
|
||||
},
|
||||
{
|
||||
name: "Zero means never expire",
|
||||
seconds: 0,
|
||||
expected: neverExpire,
|
||||
},
|
||||
{
|
||||
name: "Negative means no limit",
|
||||
seconds: -1,
|
||||
expected: noLimit,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := formatTimeout(tt.seconds)
|
||||
if result != tt.expected {
|
||||
t.Errorf("formatTimeout(%d) = %s, want %s", tt.seconds, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Positive count",
|
||||
count: 12,
|
||||
expected: "12",
|
||||
},
|
||||
{
|
||||
name: "Zero means no limit",
|
||||
count: 0,
|
||||
expected: noLimit,
|
||||
},
|
||||
{
|
||||
name: "Negative means no limit",
|
||||
count: -1,
|
||||
expected: noLimit,
|
||||
},
|
||||
{
|
||||
name: "Large count",
|
||||
count: 9999,
|
||||
expected: "9999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := formatCount(tt.count)
|
||||
if result != tt.expected {
|
||||
t.Errorf("formatCount(%d) = %s, want %s", tt.count, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConfigLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
label string
|
||||
value any
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "String value",
|
||||
label: "Token Name",
|
||||
value: "sa-token",
|
||||
contains: []string{
|
||||
"Token Name",
|
||||
"sa-token",
|
||||
"│",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Boolean value",
|
||||
label: "Auto Renew",
|
||||
value: true,
|
||||
contains: []string{
|
||||
"Auto Renew",
|
||||
"true",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Integer value",
|
||||
label: "Max Count",
|
||||
value: 12,
|
||||
contains: []string{
|
||||
"Max Count",
|
||||
"12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := formatConfigLine(tt.label, tt.value)
|
||||
for _, s := range tt.contains {
|
||||
if !strings.Contains(result, s) {
|
||||
t.Errorf("formatConfigLine(%s, %v) should contain %s, got: %s", tt.label, tt.value, s, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintWithConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *config.Config
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "Default configuration",
|
||||
config: config.DefaultConfig(),
|
||||
contains: []string{
|
||||
"Configuration",
|
||||
"Token Name",
|
||||
"sa-token",
|
||||
"Token Style",
|
||||
"uuid",
|
||||
"Token Timeout",
|
||||
"30 days",
|
||||
"Auto Renew",
|
||||
"Concurrent",
|
||||
"Share Token",
|
||||
"Max Login Count",
|
||||
"Read From Header",
|
||||
"Read From Cookie",
|
||||
"Read From Body",
|
||||
"Logging",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "JWT configuration",
|
||||
config: &config.Config{
|
||||
TokenName: "jwt-token",
|
||||
Timeout: 3600,
|
||||
ActiveTimeout: -1,
|
||||
IsConcurrent: true,
|
||||
IsShare: false,
|
||||
MaxLoginCount: 5,
|
||||
IsReadBody: false,
|
||||
IsReadHeader: true,
|
||||
IsReadCookie: false,
|
||||
TokenStyle: config.TokenStyleJWT,
|
||||
AutoRenew: true,
|
||||
JwtSecretKey: "my-secret-key",
|
||||
IsLog: true,
|
||||
CookieConfig: &config.CookieConfig{
|
||||
Path: "/api",
|
||||
SameSite: config.SameSiteLax,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
},
|
||||
},
|
||||
contains: []string{
|
||||
"jwt-token",
|
||||
"jwt",
|
||||
"3600 seconds",
|
||||
"JWT Secret",
|
||||
"*** (configured)",
|
||||
"Cookie Path",
|
||||
"/api",
|
||||
"Cookie SameSite",
|
||||
"Cookie HttpOnly",
|
||||
"Cookie Secure",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Never expire configuration",
|
||||
config: &config.Config{
|
||||
TokenName: "never-token",
|
||||
Timeout: 0,
|
||||
ActiveTimeout: -1,
|
||||
IsConcurrent: false,
|
||||
IsShare: true,
|
||||
MaxLoginCount: -1,
|
||||
TokenStyle: config.TokenStyleUUID,
|
||||
CookieConfig: &config.CookieConfig{},
|
||||
},
|
||||
contains: []string{
|
||||
"Never Expire",
|
||||
"No Limit",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "JWT without secret key",
|
||||
config: &config.Config{
|
||||
TokenName: "jwt-token",
|
||||
TokenStyle: config.TokenStyleJWT,
|
||||
JwtSecretKey: "",
|
||||
CookieConfig: &config.CookieConfig{},
|
||||
},
|
||||
contains: []string{
|
||||
"JWT Secret",
|
||||
"Not Set",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
output := captureOutput(func() {
|
||||
PrintWithConfig(tt.config)
|
||||
})
|
||||
|
||||
for _, s := range tt.contains {
|
||||
if !strings.Contains(output, s) {
|
||||
t.Errorf("PrintWithConfig() output should contain '%s'\nGot output:\n%s", s, output)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for box drawing characters
|
||||
if !strings.Contains(output, "┌") || !strings.Contains(output, "└") {
|
||||
t.Error("Output should contain box drawing characters")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintWithConfigNilCookie(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
TokenName: "test-token",
|
||||
TokenStyle: config.TokenStyleSimple,
|
||||
CookieConfig: nil, // nil cookie config
|
||||
}
|
||||
|
||||
output := captureOutput(func() {
|
||||
PrintWithConfig(cfg)
|
||||
})
|
||||
|
||||
// Should not panic and should not contain cookie configuration
|
||||
if strings.Contains(output, "Cookie Path") {
|
||||
t.Error("Output should not contain Cookie configuration when CookieConfig is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPrint(b *testing.B) {
|
||||
// Redirect output to discard
|
||||
old := os.Stdout
|
||||
os.Stdout, _ = os.Open(os.DevNull)
|
||||
defer func() { os.Stdout = old }()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
Print()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPrintWithConfig(b *testing.B) {
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// Redirect output to discard
|
||||
old := os.Stdout
|
||||
os.Stdout, _ = os.Open(os.DevNull)
|
||||
defer func() { os.Stdout = old }()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
PrintWithConfig(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFormatTimeout(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
formatTimeout(2592000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFormatCount(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
formatCount(12)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFormatConfigLine(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
formatConfigLine("Token Name", "sa-token")
|
||||
}
|
||||
}
|
||||
+165
-41
@@ -1,6 +1,7 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/click33/sa-token-go/core/adapter"
|
||||
@@ -11,39 +12,52 @@ import (
|
||||
|
||||
// Builder Sa-Token builder for fluent configuration | Sa-Token构建器,用于流式配置
|
||||
type Builder struct {
|
||||
storage adapter.Storage
|
||||
tokenName string
|
||||
timeout int64
|
||||
activeTimeout int64
|
||||
isConcurrent bool
|
||||
isShare bool
|
||||
maxLoginCount int
|
||||
tokenStyle config.TokenStyle
|
||||
autoRenew bool
|
||||
jwtSecretKey string
|
||||
isLog bool
|
||||
isPrintBanner bool
|
||||
isReadBody bool
|
||||
isReadHeader bool
|
||||
isReadCookie bool
|
||||
storage adapter.Storage
|
||||
tokenName string
|
||||
timeout int64
|
||||
activeTimeout int64
|
||||
isConcurrent bool
|
||||
isShare bool
|
||||
maxLoginCount int
|
||||
tokenStyle config.TokenStyle
|
||||
autoRenew bool
|
||||
jwtSecretKey string
|
||||
isLog bool
|
||||
isPrintBanner bool
|
||||
isReadBody bool
|
||||
isReadHeader bool
|
||||
isReadCookie bool
|
||||
dataRefreshPeriod int64
|
||||
tokenSessionCheckLogin bool
|
||||
cookieConfig *config.CookieConfig
|
||||
}
|
||||
|
||||
// NewBuilder creates a new builder | 创建新的构建器
|
||||
// NewBuilder creates a new builder with default configuration | 创建新的构建器(使用默认配置)
|
||||
func NewBuilder() *Builder {
|
||||
return &Builder{
|
||||
tokenName: "satoken",
|
||||
timeout: 2592000, // 30 days | 30天
|
||||
activeTimeout: -1,
|
||||
isConcurrent: true,
|
||||
isShare: true,
|
||||
maxLoginCount: 12,
|
||||
tokenStyle: config.TokenStyleUUID,
|
||||
autoRenew: true,
|
||||
isLog: false,
|
||||
isPrintBanner: true, // Print banner by default | 默认打印 Banner
|
||||
isReadBody: false, // Don't read from body by default | 默认不从 Body 读取
|
||||
isReadHeader: true, // Read from header by default | 默认从 Header 读取
|
||||
isReadCookie: false, // Don't read from cookie by default | 默认不从 Cookie 读取
|
||||
tokenName: config.DefaultTokenName,
|
||||
timeout: config.DefaultTimeout,
|
||||
activeTimeout: config.NoLimit,
|
||||
isConcurrent: true,
|
||||
isShare: true,
|
||||
maxLoginCount: config.DefaultMaxLoginCount,
|
||||
tokenStyle: config.TokenStyleUUID,
|
||||
autoRenew: true,
|
||||
isLog: false,
|
||||
isPrintBanner: true,
|
||||
isReadBody: false,
|
||||
isReadHeader: true,
|
||||
isReadCookie: false,
|
||||
dataRefreshPeriod: config.NoLimit,
|
||||
tokenSessionCheckLogin: true,
|
||||
cookieConfig: &config.CookieConfig{
|
||||
Domain: "",
|
||||
Path: config.DefaultCookiePath,
|
||||
Secure: false,
|
||||
HttpOnly: true,
|
||||
SameSite: config.SameSiteLax,
|
||||
MaxAge: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,10 +157,122 @@ func (b *Builder) IsReadCookie(isRead bool) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// DataRefreshPeriod sets data refresh period | 设置数据刷新周期
|
||||
func (b *Builder) DataRefreshPeriod(seconds int64) *Builder {
|
||||
b.dataRefreshPeriod = seconds
|
||||
return b
|
||||
}
|
||||
|
||||
// TokenSessionCheckLogin sets whether to check token session on login | 设置登录时是否检查Token会话
|
||||
func (b *Builder) TokenSessionCheckLogin(check bool) *Builder {
|
||||
b.tokenSessionCheckLogin = check
|
||||
return b
|
||||
}
|
||||
|
||||
// CookieDomain sets cookie domain | 设置Cookie域名
|
||||
func (b *Builder) CookieDomain(domain string) *Builder {
|
||||
if b.cookieConfig == nil {
|
||||
b.cookieConfig = &config.CookieConfig{}
|
||||
}
|
||||
b.cookieConfig.Domain = domain
|
||||
return b
|
||||
}
|
||||
|
||||
// CookiePath sets cookie path | 设置Cookie路径
|
||||
func (b *Builder) CookiePath(path string) *Builder {
|
||||
if b.cookieConfig == nil {
|
||||
b.cookieConfig = &config.CookieConfig{}
|
||||
}
|
||||
b.cookieConfig.Path = path
|
||||
return b
|
||||
}
|
||||
|
||||
// CookieSecure sets cookie secure flag | 设置Cookie的Secure标志
|
||||
func (b *Builder) CookieSecure(secure bool) *Builder {
|
||||
if b.cookieConfig == nil {
|
||||
b.cookieConfig = &config.CookieConfig{}
|
||||
}
|
||||
b.cookieConfig.Secure = secure
|
||||
return b
|
||||
}
|
||||
|
||||
// CookieHttpOnly sets cookie httpOnly flag | 设置Cookie的HttpOnly标志
|
||||
func (b *Builder) CookieHttpOnly(httpOnly bool) *Builder {
|
||||
if b.cookieConfig == nil {
|
||||
b.cookieConfig = &config.CookieConfig{}
|
||||
}
|
||||
b.cookieConfig.HttpOnly = httpOnly
|
||||
return b
|
||||
}
|
||||
|
||||
// CookieSameSite sets cookie sameSite attribute | 设置Cookie的SameSite属性
|
||||
func (b *Builder) CookieSameSite(sameSite config.SameSiteMode) *Builder {
|
||||
if b.cookieConfig == nil {
|
||||
b.cookieConfig = &config.CookieConfig{}
|
||||
}
|
||||
b.cookieConfig.SameSite = sameSite
|
||||
return b
|
||||
}
|
||||
|
||||
// CookieMaxAge sets cookie max age | 设置Cookie的最大年龄
|
||||
func (b *Builder) CookieMaxAge(maxAge int) *Builder {
|
||||
if b.cookieConfig == nil {
|
||||
b.cookieConfig = &config.CookieConfig{}
|
||||
}
|
||||
b.cookieConfig.MaxAge = maxAge
|
||||
return b
|
||||
}
|
||||
|
||||
// CookieConfig sets complete cookie configuration | 设置完整的Cookie配置
|
||||
func (b *Builder) CookieConfig(cfg *config.CookieConfig) *Builder {
|
||||
b.cookieConfig = cfg
|
||||
return b
|
||||
}
|
||||
|
||||
// NeverExpire sets token to never expire | 设置Token永不过期
|
||||
func (b *Builder) NeverExpire() *Builder {
|
||||
b.timeout = config.NoLimit
|
||||
return b
|
||||
}
|
||||
|
||||
// NoActiveTimeout disables active timeout | 禁用活跃超时
|
||||
func (b *Builder) NoActiveTimeout() *Builder {
|
||||
b.activeTimeout = config.NoLimit
|
||||
return b
|
||||
}
|
||||
|
||||
// UnlimitedLogin allows unlimited concurrent logins | 允许无限并发登录
|
||||
func (b *Builder) UnlimitedLogin() *Builder {
|
||||
b.maxLoginCount = config.NoLimit
|
||||
return b
|
||||
}
|
||||
|
||||
// Validate validates the builder configuration | 验证构建器配置
|
||||
func (b *Builder) Validate() error {
|
||||
if b.storage == nil {
|
||||
return fmt.Errorf("storage is required, please call Storage() method")
|
||||
}
|
||||
|
||||
if b.tokenName == "" {
|
||||
return fmt.Errorf("tokenName cannot be empty")
|
||||
}
|
||||
|
||||
if b.tokenStyle == config.TokenStyleJWT && b.jwtSecretKey == "" {
|
||||
return fmt.Errorf("jwtSecretKey is required when TokenStyle is JWT")
|
||||
}
|
||||
|
||||
if !b.isReadHeader && !b.isReadCookie && !b.isReadBody {
|
||||
return fmt.Errorf("at least one of IsReadHeader, IsReadCookie, or IsReadBody must be true")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build builds Manager and prints startup banner | 构建Manager并打印启动Banner
|
||||
func (b *Builder) Build() *manager.Manager {
|
||||
if b.storage == nil {
|
||||
panic("storage is required, please call Storage() method")
|
||||
// Validate configuration | 验证配置
|
||||
if err := b.Validate(); err != nil {
|
||||
panic(fmt.Sprintf("invalid configuration: %v", err))
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
@@ -160,20 +286,13 @@ func (b *Builder) Build() *manager.Manager {
|
||||
IsReadHeader: b.isReadHeader,
|
||||
IsReadCookie: b.isReadCookie,
|
||||
TokenStyle: b.tokenStyle,
|
||||
DataRefreshPeriod: -1,
|
||||
TokenSessionCheckLogin: true,
|
||||
DataRefreshPeriod: b.dataRefreshPeriod,
|
||||
TokenSessionCheckLogin: b.tokenSessionCheckLogin,
|
||||
AutoRenew: b.autoRenew,
|
||||
JwtSecretKey: b.jwtSecretKey,
|
||||
IsLog: b.isLog,
|
||||
IsPrintBanner: b.isPrintBanner,
|
||||
CookieConfig: &config.CookieConfig{
|
||||
Domain: "",
|
||||
Path: "/",
|
||||
Secure: false,
|
||||
HttpOnly: true,
|
||||
SameSite: "Lax",
|
||||
MaxAge: 0,
|
||||
},
|
||||
CookieConfig: b.cookieConfig,
|
||||
}
|
||||
|
||||
// Print startup banner with full configuration | 打印启动Banner和完整配置
|
||||
@@ -189,3 +308,8 @@ func (b *Builder) Build() *manager.Manager {
|
||||
|
||||
return mgr
|
||||
}
|
||||
|
||||
// MustBuild builds Manager and panics if validation fails | 构建Manager,验证失败时panic
|
||||
func (b *Builder) MustBuild() *manager.Manager {
|
||||
return b.Build()
|
||||
}
|
||||
|
||||
+141
-18
@@ -1,5 +1,7 @@
|
||||
package config
|
||||
|
||||
import "fmt"
|
||||
|
||||
// TokenStyle Token generation style | Token生成风格
|
||||
type TokenStyle string
|
||||
|
||||
@@ -24,6 +26,39 @@ const (
|
||||
TokenStyleTik TokenStyle = "tik"
|
||||
)
|
||||
|
||||
// SameSiteMode Cookie SameSite attribute values | Cookie的SameSite属性值
|
||||
type SameSiteMode string
|
||||
|
||||
const (
|
||||
// SameSiteStrict Strict mode | 严格模式
|
||||
SameSiteStrict SameSiteMode = "Strict"
|
||||
// SameSiteLax Lax mode | 宽松模式
|
||||
SameSiteLax SameSiteMode = "Lax"
|
||||
// SameSiteNone None mode | 无限制模式
|
||||
SameSiteNone SameSiteMode = "None"
|
||||
)
|
||||
|
||||
// Default configuration constants | 默认配置常量
|
||||
const (
|
||||
DefaultTokenName = "satoken"
|
||||
DefaultTimeout = 2592000 // 30 days in seconds | 30天(秒)
|
||||
DefaultMaxLoginCount = 12 // Maximum concurrent logins | 最大并发登录数
|
||||
DefaultCookiePath = "/"
|
||||
NoLimit = -1 // No limit flag | 不限制标志
|
||||
)
|
||||
|
||||
// IsValid checks if the TokenStyle is valid | 检查TokenStyle是否有效
|
||||
func (ts TokenStyle) IsValid() bool {
|
||||
switch ts {
|
||||
case TokenStyleUUID, TokenStyleSimple, TokenStyleRandom32,
|
||||
TokenStyleRandom64, TokenStyleRandom128, TokenStyleJWT,
|
||||
TokenStyleHash, TokenStyleTimestamp, TokenStyleTik:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Config Sa-Token configuration | Sa-Token配置
|
||||
type Config struct {
|
||||
// TokenName Token name (also used as Cookie name) | Token名称(同时也是Cookie名称)
|
||||
@@ -93,7 +128,7 @@ type CookieConfig struct {
|
||||
HttpOnly bool
|
||||
|
||||
// SameSite SameSite attribute (Strict, Lax, None) | SameSite属性(Strict、Lax、None)
|
||||
SameSite string
|
||||
SameSite SameSiteMode
|
||||
|
||||
// MaxAge Cookie expiration time in seconds | 过期时间(单位:秒)
|
||||
MaxAge int
|
||||
@@ -102,33 +137,73 @@ type CookieConfig struct {
|
||||
// DefaultConfig Returns default configuration | 返回默认配置
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
TokenName: "sa-token",
|
||||
Timeout: 2592000, // 30 days | 30天
|
||||
ActiveTimeout: -1, // No limit | 不限制
|
||||
IsConcurrent: true, // Allow concurrent login | 允许并发登录
|
||||
IsShare: true, // Share Token | 共享Token
|
||||
MaxLoginCount: 12, // Max 12 logins | 最多12个
|
||||
IsReadBody: false, // Don't read from Body (default) | 不从Body读取(默认)
|
||||
IsReadHeader: true, // Read from Header (recommended) | 从Header读取(推荐)
|
||||
IsReadCookie: false, // Don't read from Cookie (default) | 不从Cookie读取(默认)
|
||||
TokenName: DefaultTokenName,
|
||||
Timeout: DefaultTimeout,
|
||||
ActiveTimeout: NoLimit,
|
||||
IsConcurrent: true,
|
||||
IsShare: true,
|
||||
MaxLoginCount: DefaultMaxLoginCount,
|
||||
IsReadBody: false,
|
||||
IsReadHeader: true,
|
||||
IsReadCookie: false,
|
||||
TokenStyle: TokenStyleUUID,
|
||||
DataRefreshPeriod: -1, // No auto-refresh | 不自动续签
|
||||
TokenSessionCheckLogin: true, // Check on login | 登录时检查
|
||||
AutoRenew: true, // Auto-renew | 自动续期
|
||||
JwtSecretKey: "", // Empty by default | 默认空
|
||||
IsLog: false, // No logging | 不输出日志
|
||||
IsPrintBanner: true, // Print startup banner | 打印启动 Banner
|
||||
DataRefreshPeriod: NoLimit,
|
||||
TokenSessionCheckLogin: true,
|
||||
AutoRenew: true,
|
||||
JwtSecretKey: "",
|
||||
IsLog: false,
|
||||
IsPrintBanner: true,
|
||||
CookieConfig: &CookieConfig{
|
||||
Domain: "",
|
||||
Path: "/",
|
||||
Path: DefaultCookiePath,
|
||||
Secure: false,
|
||||
HttpOnly: true,
|
||||
SameSite: "Lax",
|
||||
SameSite: SameSiteLax,
|
||||
MaxAge: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the configuration | 验证配置是否合理
|
||||
func (c *Config) Validate() error {
|
||||
// Check TokenName
|
||||
if c.TokenName == "" {
|
||||
return fmt.Errorf("TokenName cannot be empty")
|
||||
}
|
||||
|
||||
// Check TokenStyle
|
||||
if !c.TokenStyle.IsValid() {
|
||||
return fmt.Errorf("invalid TokenStyle: %s", c.TokenStyle)
|
||||
}
|
||||
|
||||
// Check JWT secret key when using JWT style
|
||||
if c.TokenStyle == TokenStyleJWT && c.JwtSecretKey == "" {
|
||||
return fmt.Errorf("JwtSecretKey is required when TokenStyle is JWT")
|
||||
}
|
||||
|
||||
// Check Timeout
|
||||
if c.Timeout < NoLimit {
|
||||
return fmt.Errorf("Timeout must be >= -1, got: %d", c.Timeout)
|
||||
}
|
||||
|
||||
// Check ActiveTimeout
|
||||
if c.ActiveTimeout < NoLimit {
|
||||
return fmt.Errorf("ActiveTimeout must be >= -1, got: %d", c.ActiveTimeout)
|
||||
}
|
||||
|
||||
// Check MaxLoginCount
|
||||
if c.MaxLoginCount < NoLimit {
|
||||
return fmt.Errorf("MaxLoginCount must be >= -1, got: %d", c.MaxLoginCount)
|
||||
}
|
||||
|
||||
// Check if at least one read source is enabled
|
||||
if !c.IsReadHeader && !c.IsReadCookie && !c.IsReadBody {
|
||||
return fmt.Errorf("at least one of IsReadHeader, IsReadCookie, or IsReadBody must be true")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone Clone configuration | 克隆配置
|
||||
func (c *Config) Clone() *Config {
|
||||
newConfig := *c
|
||||
@@ -169,12 +244,48 @@ func (c *Config) SetIsShare(isShare bool) *Config {
|
||||
return c
|
||||
}
|
||||
|
||||
// SetMaxLoginCount Set maximum login count | 设置最大登录数量
|
||||
func (c *Config) SetMaxLoginCount(count int) *Config {
|
||||
c.MaxLoginCount = count
|
||||
return c
|
||||
}
|
||||
|
||||
// SetIsReadBody Set whether to read Token from body | 设置是否从请求体读取Token
|
||||
func (c *Config) SetIsReadBody(isReadBody bool) *Config {
|
||||
c.IsReadBody = isReadBody
|
||||
return c
|
||||
}
|
||||
|
||||
// SetIsReadHeader Set whether to read Token from header | 设置是否从Header读取Token
|
||||
func (c *Config) SetIsReadHeader(isReadHeader bool) *Config {
|
||||
c.IsReadHeader = isReadHeader
|
||||
return c
|
||||
}
|
||||
|
||||
// SetIsReadCookie Set whether to read Token from cookie | 设置是否从Cookie读取Token
|
||||
func (c *Config) SetIsReadCookie(isReadCookie bool) *Config {
|
||||
c.IsReadCookie = isReadCookie
|
||||
return c
|
||||
}
|
||||
|
||||
// SetTokenStyle Set Token generation style | 设置Token风格
|
||||
func (c *Config) SetTokenStyle(style TokenStyle) *Config {
|
||||
c.TokenStyle = style
|
||||
return c
|
||||
}
|
||||
|
||||
// SetDataRefreshPeriod Set data refresh period | 设置数据刷新周期
|
||||
func (c *Config) SetDataRefreshPeriod(period int64) *Config {
|
||||
c.DataRefreshPeriod = period
|
||||
return c
|
||||
}
|
||||
|
||||
// SetTokenSessionCheckLogin Set whether to check token session on login | 设置登录时是否检查token会话
|
||||
func (c *Config) SetTokenSessionCheckLogin(check bool) *Config {
|
||||
c.TokenSessionCheckLogin = check
|
||||
return c
|
||||
}
|
||||
|
||||
// SetJwtSecretKey Set JWT secret key | 设置JWT密钥
|
||||
func (c *Config) SetJwtSecretKey(key string) *Config {
|
||||
c.JwtSecretKey = key
|
||||
@@ -192,3 +303,15 @@ func (c *Config) SetIsLog(isLog bool) *Config {
|
||||
c.IsLog = isLog
|
||||
return c
|
||||
}
|
||||
|
||||
// SetIsPrintBanner Set whether to print banner | 设置是否打印Banner
|
||||
func (c *Config) SetIsPrintBanner(isPrint bool) *Config {
|
||||
c.IsPrintBanner = isPrint
|
||||
return c
|
||||
}
|
||||
|
||||
// SetCookieConfig Set cookie configuration | 设置Cookie配置
|
||||
func (c *Config) SetCookieConfig(cookieConfig *CookieConfig) *Config {
|
||||
c.CookieConfig = cookieConfig
|
||||
return c
|
||||
}
|
||||
+31
-13
@@ -1,10 +1,17 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/click33/sa-token-go/core/adapter"
|
||||
"github.com/click33/sa-token-go/core/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
bearerPrefix = "Bearer "
|
||||
authHeader = "Authorization"
|
||||
)
|
||||
|
||||
// SaTokenContext Sa-Token context for current request | Sa-Token上下文,用于当前请求
|
||||
type SaTokenContext struct {
|
||||
ctx adapter.RequestContext
|
||||
@@ -19,38 +26,49 @@ func NewContext(ctx adapter.RequestContext, mgr *manager.Manager) *SaTokenContex
|
||||
}
|
||||
}
|
||||
|
||||
// extractBearerToken 从 Authorization 头中提取 Bearer Token
|
||||
func extractBearerToken(auth string) string {
|
||||
auth = strings.TrimSpace(auth)
|
||||
if auth == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 支持大小写不敏感的 Bearer 前缀
|
||||
if len(auth) > 7 && strings.EqualFold(auth[:7], bearerPrefix) {
|
||||
return strings.TrimSpace(auth[7:])
|
||||
}
|
||||
|
||||
return auth
|
||||
}
|
||||
|
||||
// GetTokenValue gets token value from current request | 获取当前请求的Token值
|
||||
func (c *SaTokenContext) GetTokenValue() string {
|
||||
cfg := c.manager.GetConfig()
|
||||
|
||||
// 1. 尝试从Header获取
|
||||
if cfg.IsReadHeader {
|
||||
token := c.ctx.GetHeader(cfg.TokenName)
|
||||
if token != "" {
|
||||
// 从自定义 token 名称的 Header 获取
|
||||
if token := strings.TrimSpace(c.ctx.GetHeader(cfg.TokenName)); token != "" {
|
||||
return token
|
||||
}
|
||||
// 也尝试从Authorization头获取
|
||||
auth := c.ctx.GetHeader("Authorization")
|
||||
if auth != "" {
|
||||
// 移除 "Bearer " 前缀
|
||||
if len(auth) > 7 && auth[:7] == "Bearer " {
|
||||
return auth[7:]
|
||||
|
||||
// 从 Authorization 头获取
|
||||
if auth := c.ctx.GetHeader(authHeader); auth != "" {
|
||||
if token := extractBearerToken(auth); token != "" {
|
||||
return token
|
||||
}
|
||||
return auth
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 尝试从Cookie获取
|
||||
if cfg.IsReadCookie {
|
||||
token := c.ctx.GetCookie(cfg.TokenName)
|
||||
if token != "" {
|
||||
if token := strings.TrimSpace(c.ctx.GetCookie(cfg.TokenName)); token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 尝试从Query参数获取
|
||||
token := c.ctx.GetQuery(cfg.TokenName)
|
||||
if token != "" {
|
||||
if token := strings.TrimSpace(c.ctx.GetQuery(cfg.TokenName)); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
|
||||
+127
-29
@@ -1,10 +1,15 @@
|
||||
package core
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Common error definitions for better error handling and internationalization support
|
||||
// 常见错误定义,用于更好的错误处理和国际化支持
|
||||
|
||||
// ============ Authentication Errors | 认证错误 ============
|
||||
|
||||
var (
|
||||
// ErrNotLogin indicates the user is not logged in | 用户未登录错误
|
||||
ErrNotLogin = fmt.Errorf("authentication required: user not logged in")
|
||||
@@ -15,20 +20,38 @@ var (
|
||||
// ErrTokenExpired indicates the token has expired | Token已过期
|
||||
ErrTokenExpired = fmt.Errorf("token expired: please login again to get a new token")
|
||||
|
||||
// ErrAccountDisabled indicates the account has been disabled or banned | 账号已被禁用
|
||||
ErrAccountDisabled = fmt.Errorf("account disabled: this account has been temporarily or permanently disabled")
|
||||
// ErrInvalidLoginID indicates the login ID is invalid | 登录ID无效
|
||||
ErrInvalidLoginID = fmt.Errorf("invalid login ID: the login identifier cannot be empty")
|
||||
|
||||
// ErrInvalidDevice indicates the device identifier is invalid | 设备标识无效
|
||||
ErrInvalidDevice = fmt.Errorf("invalid device: the device identifier is not valid")
|
||||
)
|
||||
|
||||
// ============ Authorization Errors | 授权错误 ============
|
||||
|
||||
var (
|
||||
// ErrPermissionDenied indicates insufficient permissions | 权限不足
|
||||
ErrPermissionDenied = fmt.Errorf("permission denied: you don't have the required permission")
|
||||
|
||||
// ErrRoleDenied indicates insufficient role | 角色权限不足
|
||||
ErrRoleDenied = fmt.Errorf("role denied: you don't have the required role")
|
||||
)
|
||||
|
||||
// ErrSessionNotFound indicates the session doesn't exist | Session不存在
|
||||
ErrSessionNotFound = fmt.Errorf("session not found: the session may have expired or been deleted")
|
||||
// ============ Account Errors | 账号错误 ============
|
||||
|
||||
var (
|
||||
// ErrAccountDisabled indicates the account has been disabled or banned | 账号已被禁用
|
||||
ErrAccountDisabled = fmt.Errorf("account disabled: this account has been temporarily or permanently disabled")
|
||||
|
||||
// ErrAccountNotFound indicates the account doesn't exist | 账号不存在
|
||||
ErrAccountNotFound = fmt.Errorf("account not found: no account associated with this identifier")
|
||||
)
|
||||
|
||||
// ============ Session Errors | 会话错误 ============
|
||||
|
||||
var (
|
||||
// ErrSessionNotFound indicates the session doesn't exist | Session不存在
|
||||
ErrSessionNotFound = fmt.Errorf("session not found: the session may have expired or been deleted")
|
||||
|
||||
// ErrKickedOut indicates the user has been kicked out | 用户已被踢下线
|
||||
ErrKickedOut = fmt.Errorf("kicked out: this session has been forcibly terminated")
|
||||
@@ -38,26 +61,26 @@ var (
|
||||
|
||||
// ErrMaxLoginCount indicates maximum concurrent login limit reached | 达到最大登录数量限制
|
||||
ErrMaxLoginCount = fmt.Errorf("max login limit: maximum number of concurrent logins reached")
|
||||
|
||||
// ErrStorageUnavailable indicates the storage backend is unavailable | 存储后端不可用
|
||||
ErrStorageUnavailable = fmt.Errorf("storage unavailable: unable to connect to storage backend")
|
||||
|
||||
// ErrInvalidLoginID indicates the login ID is invalid | 登录ID无效
|
||||
ErrInvalidLoginID = fmt.Errorf("invalid login ID: the login identifier cannot be empty")
|
||||
|
||||
// ErrInvalidDevice indicates the device identifier is invalid | 设备标识无效
|
||||
ErrInvalidDevice = fmt.Errorf("invalid device: the device identifier is not valid")
|
||||
)
|
||||
|
||||
// SaTokenError represents a custom error with error code and context | 自定义错误类型,包含错误码和上下文信息
|
||||
// ============ System Errors | 系统错误 ============
|
||||
|
||||
var (
|
||||
// ErrStorageUnavailable indicates the storage backend is unavailable | 存储后端不可用
|
||||
ErrStorageUnavailable = fmt.Errorf("storage unavailable: unable to connect to storage backend")
|
||||
)
|
||||
|
||||
// ============ Custom Error Type | 自定义错误类型 ============
|
||||
|
||||
// SaTokenError Represents a custom error with error code and context | 自定义错误类型,包含错误码和上下文信息
|
||||
type SaTokenError struct {
|
||||
Code int // Error code for programmatic handling | 错误码,用于程序化处理
|
||||
Message string // Human-readable error message | 可读的错误消息
|
||||
Err error // Underlying error (if any) | 底层错误(如果有)
|
||||
Context map[string]interface{} // Additional context information | 额外的上下文信息
|
||||
Code int // Error code for programmatic handling | 错误码,用于程序化处理
|
||||
Message string // Human-readable error message | 可读的错误消息
|
||||
Err error // Underlying error (if any) | 底层错误(如果有)
|
||||
Context map[string]any // Additional context information | 额外的上下文信息
|
||||
}
|
||||
|
||||
// Error implements the error interface | 实现 error 接口
|
||||
// Error Implements the error interface | 实现 error 接口
|
||||
func (e *SaTokenError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("%s (code: %d): %v", e.Message, e.Code, e.Err)
|
||||
@@ -65,32 +88,52 @@ func (e *SaTokenError) Error() string {
|
||||
return fmt.Sprintf("%s (code: %d)", e.Message, e.Code)
|
||||
}
|
||||
|
||||
// Unwrap implements the unwrap interface for error chains | 实现 unwrap 接口,支持错误链
|
||||
// Unwrap Implements the unwrap interface for error chains | 实现 unwrap 接口,支持错误链
|
||||
func (e *SaTokenError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// WithContext adds context information to the error | 为错误添加上下文信息
|
||||
func (e *SaTokenError) WithContext(key string, value interface{}) *SaTokenError {
|
||||
// WithContext Adds context information to the error | 为错误添加上下文信息
|
||||
func (e *SaTokenError) WithContext(key string, value any) *SaTokenError {
|
||||
if e.Context == nil {
|
||||
e.Context = make(map[string]interface{})
|
||||
e.Context = make(map[string]any)
|
||||
}
|
||||
e.Context[key] = value
|
||||
return e
|
||||
}
|
||||
|
||||
// NewError creates a new Sa-Token error | 创建新的 Sa-Token 错误
|
||||
// GetContext Gets context value | 获取上下文值
|
||||
func (e *SaTokenError) GetContext(key string) (any, bool) {
|
||||
if e.Context == nil {
|
||||
return nil, false
|
||||
}
|
||||
val, exists := e.Context[key]
|
||||
return val, exists
|
||||
}
|
||||
|
||||
// Is Implements errors.Is for error comparison | 实现 errors.Is 进行错误比较
|
||||
func (e *SaTokenError) Is(target error) bool {
|
||||
t, ok := target.(*SaTokenError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return e.Code == t.Code
|
||||
}
|
||||
|
||||
// ============ Error Constructors | 错误构造函数 ============
|
||||
|
||||
// NewError Creates a new Sa-Token error | 创建新的 Sa-Token 错误
|
||||
func NewError(code int, message string, err error) *SaTokenError {
|
||||
return &SaTokenError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Err: err,
|
||||
Context: make(map[string]interface{}),
|
||||
Context: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorWithContext creates a new Sa-Token error with context | 创建带上下文的 Sa-Token 错误
|
||||
func NewErrorWithContext(code int, message string, err error, context map[string]interface{}) *SaTokenError {
|
||||
// NewErrorWithContext Creates a new Sa-Token error with context | 创建带上下文的 Sa-Token 错误
|
||||
func NewErrorWithContext(code int, message string, err error, context map[string]any) *SaTokenError {
|
||||
return &SaTokenError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
@@ -99,7 +142,62 @@ func NewErrorWithContext(code int, message string, err error, context map[string
|
||||
}
|
||||
}
|
||||
|
||||
// Error code definitions | 错误码定义
|
||||
// NewNotLoginError Creates a not login error | 创建未登录错误
|
||||
func NewNotLoginError() *SaTokenError {
|
||||
return NewError(CodeNotLogin, "user not logged in", ErrNotLogin)
|
||||
}
|
||||
|
||||
// NewPermissionDeniedError Creates a permission denied error | 创建权限拒绝错误
|
||||
func NewPermissionDeniedError(permission string) *SaTokenError {
|
||||
return NewError(CodePermissionDenied, "permission denied", ErrPermissionDenied).
|
||||
WithContext("permission", permission)
|
||||
}
|
||||
|
||||
// NewRoleDeniedError Creates a role denied error | 创建角色拒绝错误
|
||||
func NewRoleDeniedError(role string) *SaTokenError {
|
||||
return NewError(CodePermissionDenied, "role denied", ErrRoleDenied).
|
||||
WithContext("role", role)
|
||||
}
|
||||
|
||||
// NewAccountDisabledError Creates an account disabled error | 创建账号禁用错误
|
||||
func NewAccountDisabledError(loginID string) *SaTokenError {
|
||||
return NewError(CodeAccountDisabled, "account disabled", ErrAccountDisabled).
|
||||
WithContext("loginID", loginID)
|
||||
}
|
||||
|
||||
// ============ Error Checking Helpers | 错误检查辅助函数 ============
|
||||
|
||||
// IsNotLoginError Checks if error is a not login error | 检查是否为未登录错误
|
||||
func IsNotLoginError(err error) bool {
|
||||
return errors.Is(err, ErrNotLogin)
|
||||
}
|
||||
|
||||
// IsPermissionDeniedError Checks if error is a permission denied error | 检查是否为权限拒绝错误
|
||||
func IsPermissionDeniedError(err error) bool {
|
||||
return errors.Is(err, ErrPermissionDenied)
|
||||
}
|
||||
|
||||
// IsAccountDisabledError Checks if error is an account disabled error | 检查是否为账号禁用错误
|
||||
func IsAccountDisabledError(err error) bool {
|
||||
return errors.Is(err, ErrAccountDisabled)
|
||||
}
|
||||
|
||||
// IsTokenError Checks if error is a token-related error | 检查是否为Token相关错误
|
||||
func IsTokenError(err error) bool {
|
||||
return errors.Is(err, ErrTokenInvalid) || errors.Is(err, ErrTokenExpired)
|
||||
}
|
||||
|
||||
// GetErrorCode Extracts error code from SaTokenError | 从SaTokenError中提取错误码
|
||||
func GetErrorCode(err error) int {
|
||||
var saErr *SaTokenError
|
||||
if errors.As(err, &saErr) {
|
||||
return saErr.Code
|
||||
}
|
||||
return CodeServerError
|
||||
}
|
||||
|
||||
// ============ Error Code Definitions | 错误码定义 ============
|
||||
|
||||
const (
|
||||
// Standard HTTP status codes | 标准 HTTP 状态码
|
||||
CodeSuccess = 200 // Request successful | 请求成功
|
||||
|
||||
+146
-15
@@ -46,12 +46,12 @@ const (
|
||||
|
||||
// EventData contains information about a triggered event | 事件数据,包含触发事件的相关信息
|
||||
type EventData struct {
|
||||
Event Event // Event type | 事件类型
|
||||
LoginID string // User login ID | 用户登录ID
|
||||
Device string // Device identifier | 设备标识
|
||||
Token string // Authentication token | 认证Token
|
||||
Extra map[string]interface{} // Additional custom data | 额外的自定义数据
|
||||
Timestamp int64 // Unix timestamp when event was triggered | 事件触发的Unix时间戳
|
||||
Event Event // Event type | 事件类型
|
||||
LoginID string // User login ID | 用户登录ID
|
||||
Device string // Device identifier | 设备标识
|
||||
Token string // Authentication token | 认证Token
|
||||
Extra map[string]any // Additional custom data | 额外的自定义数据
|
||||
Timestamp int64 // Unix timestamp when event was triggered | 事件触发的Unix时间戳
|
||||
}
|
||||
|
||||
// String returns a string representation of the event data | 返回事件数据的字符串表示
|
||||
@@ -87,35 +87,106 @@ type listenerEntry struct {
|
||||
config ListenerConfig
|
||||
}
|
||||
|
||||
// EventFilter is a function that decides whether an event should be processed | 事件过滤器,决定事件是否应该被处理
|
||||
type EventFilter func(data *EventData) bool
|
||||
|
||||
// EventStats contains statistics about event processing | 事件统计信息
|
||||
type EventStats struct {
|
||||
TotalTriggered int64 // Total number of events triggered | 触发的事件总数
|
||||
EventCounts map[Event]int64 // Count per event type | 各类型事件的计数
|
||||
LastTriggered map[Event]time.Time // Last trigger time per event | 各类型事件的最后触发时间
|
||||
}
|
||||
|
||||
// Manager manages event listeners and dispatches events | 事件管理器,管理监听器并分发事件
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
listeners map[Event][]listenerEntry
|
||||
panicHandler func(event Event, data *EventData, recovered interface{})
|
||||
panicHandler func(event Event, data *EventData, recovered any)
|
||||
listenerCounter int
|
||||
enabledEvents map[Event]bool // If nil, all events are enabled | 如果为nil,所有事件都启用
|
||||
asyncWaitGroup sync.WaitGroup // For waiting on async listeners during shutdown | 用于等待异步监听器完成
|
||||
filters []EventFilter // Global event filters | 全局事件过滤器
|
||||
stats *EventStats // Event statistics | 事件统计
|
||||
enableStats bool // Whether to collect statistics | 是否收集统计信息
|
||||
}
|
||||
|
||||
// NewManager creates a new event manager | 创建新的事件管理器
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
listeners: make(map[Event][]listenerEntry),
|
||||
panicHandler: func(event Event, data *EventData, recovered interface{}) {
|
||||
panicHandler: func(event Event, data *EventData, recovered any) {
|
||||
// Default panic handler: log but don't crash | 默认panic处理器:记录日志但不崩溃
|
||||
fmt.Printf("sa-token: listener panic recovered: event=%s, panic=%v\n", event, recovered)
|
||||
},
|
||||
enabledEvents: nil, // All events enabled by default | 默认启用所有事件
|
||||
filters: make([]EventFilter, 0),
|
||||
stats: &EventStats{
|
||||
EventCounts: make(map[Event]int64),
|
||||
LastTriggered: make(map[Event]time.Time),
|
||||
},
|
||||
enableStats: false, // Stats disabled by default | 默认不启用统计
|
||||
}
|
||||
}
|
||||
|
||||
// SetPanicHandler sets a custom panic handler for listener errors | 设置自定义的panic处理器
|
||||
func (m *Manager) SetPanicHandler(handler func(event Event, data *EventData, recovered interface{})) {
|
||||
func (m *Manager) SetPanicHandler(handler func(event Event, data *EventData, recovered any)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.panicHandler = handler
|
||||
}
|
||||
|
||||
// AddFilter adds a global event filter | 添加全局事件过滤器
|
||||
func (m *Manager) AddFilter(filter EventFilter) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.filters = append(m.filters, filter)
|
||||
}
|
||||
|
||||
// ClearFilters removes all event filters | 清除所有事件过滤器
|
||||
func (m *Manager) ClearFilters() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.filters = make([]EventFilter, 0)
|
||||
}
|
||||
|
||||
// EnableStats enables event statistics collection | 启用事件统计
|
||||
func (m *Manager) EnableStats(enable bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.enableStats = enable
|
||||
}
|
||||
|
||||
// GetStats returns a copy of event statistics | 获取事件统计信息副本
|
||||
func (m *Manager) GetStats() EventStats {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
stats := EventStats{
|
||||
TotalTriggered: m.stats.TotalTriggered,
|
||||
EventCounts: make(map[Event]int64),
|
||||
LastTriggered: make(map[Event]time.Time),
|
||||
}
|
||||
|
||||
for event, count := range m.stats.EventCounts {
|
||||
stats.EventCounts[event] = count
|
||||
}
|
||||
for event, t := range m.stats.LastTriggered {
|
||||
stats.LastTriggered[event] = t
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// ResetStats resets event statistics | 重置事件统计
|
||||
func (m *Manager) ResetStats() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.stats = &EventStats{
|
||||
EventCounts: make(map[Event]int64),
|
||||
LastTriggered: make(map[Event]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// EnableEvent enables specific events (disables all others) | 启用特定事件(禁用其他所有事件)
|
||||
// Call with no arguments to enable all events | 不传参数时启用所有事件
|
||||
func (m *Manager) EnableEvent(events ...Event) {
|
||||
@@ -230,13 +301,15 @@ func (m *Manager) Unregister(listenerID string) bool {
|
||||
// sortListeners sorts listeners by priority (descending)
|
||||
func (m *Manager) sortListeners(event Event) {
|
||||
entries := m.listeners[event]
|
||||
// Simple bubble sort (listeners count is usually small)
|
||||
for i := 0; i < len(entries)-1; i++ {
|
||||
for j := 0; j < len(entries)-i-1; j++ {
|
||||
if entries[j].config.Priority < entries[j+1].config.Priority {
|
||||
entries[j], entries[j+1] = entries[j+1], entries[j]
|
||||
}
|
||||
// Use insertion sort (efficient for small lists and maintains stability)
|
||||
for i := 1; i < len(entries); i++ {
|
||||
key := entries[i]
|
||||
j := i - 1
|
||||
for j >= 0 && entries[j].config.Priority < key.config.Priority {
|
||||
entries[j+1] = entries[j]
|
||||
j--
|
||||
}
|
||||
entries[j+1] = key
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,6 +328,21 @@ func (m *Manager) Trigger(data *EventData) {
|
||||
data.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
for _, filter := range m.filters {
|
||||
if !filter(data) {
|
||||
m.mu.RUnlock()
|
||||
return // Event filtered out
|
||||
}
|
||||
}
|
||||
|
||||
// Update statistics
|
||||
if m.enableStats {
|
||||
m.stats.TotalTriggered++
|
||||
m.stats.EventCounts[data.Event]++
|
||||
m.stats.LastTriggered[data.Event] = time.Now()
|
||||
}
|
||||
|
||||
// Collect listeners to call
|
||||
var listenersToCall []listenerEntry
|
||||
|
||||
@@ -281,6 +369,17 @@ func (m *Manager) Trigger(data *EventData) {
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回
|
||||
func (m *Manager) TriggerAsync(data *EventData) {
|
||||
go m.Trigger(data)
|
||||
}
|
||||
|
||||
// TriggerSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成
|
||||
func (m *Manager) TriggerSync(data *EventData) {
|
||||
m.Trigger(data)
|
||||
m.Wait()
|
||||
}
|
||||
|
||||
// safeCall executes a listener with panic recovery
|
||||
func (m *Manager) safeCall(listener Listener, data *EventData, wg *sync.WaitGroup) {
|
||||
if wg != nil {
|
||||
@@ -339,3 +438,35 @@ func (m *Manager) CountForEvent(event Event) int {
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.listeners[event])
|
||||
}
|
||||
|
||||
// GetListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID
|
||||
func (m *Manager) GetListenerIDs(event Event) []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
entries := m.listeners[event]
|
||||
ids := make([]string, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
ids = append(ids, entry.config.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetAllEvents returns all events that have registered listeners | 获取所有已注册监听器的事件
|
||||
func (m *Manager) GetAllEvents() []Event {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
events := make([]Event, 0, len(m.listeners))
|
||||
for event := range m.listeners {
|
||||
events = append(events, event)
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
// HasListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器
|
||||
func (m *Manager) HasListeners(event Event) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.listeners[event]) > 0
|
||||
}
|
||||
|
||||
+226
-154
@@ -14,16 +14,48 @@ import (
|
||||
"github.com/click33/sa-token-go/core/token"
|
||||
)
|
||||
|
||||
// TokenInfo Token信息
|
||||
// Constants for storage keys and default values | 存储键和默认值常量
|
||||
const (
|
||||
DefaultDevice = "default"
|
||||
DefaultPrefix = "satoken:"
|
||||
DisableValue = "1"
|
||||
DefaultNonceTTL = 5 * time.Minute
|
||||
|
||||
// Key prefixes | 键前缀
|
||||
TokenKeyPrefix = "token:"
|
||||
AccountKeyPrefix = "account:"
|
||||
DisableKeyPrefix = "disable:"
|
||||
|
||||
// Session keys | Session键
|
||||
SessionKeyLoginID = "loginId"
|
||||
SessionKeyDevice = "device"
|
||||
SessionKeyLoginTime = "loginTime"
|
||||
SessionKeyPermissions = "permissions"
|
||||
SessionKeyRoles = "roles"
|
||||
|
||||
// Wildcard for permissions | 权限通配符
|
||||
PermissionWildcard = "*"
|
||||
PermissionSeparator = ":"
|
||||
)
|
||||
|
||||
// Error variables | 错误变量
|
||||
var (
|
||||
ErrAccountDisabled = fmt.Errorf("account is disabled")
|
||||
ErrNotLogin = fmt.Errorf("not login")
|
||||
ErrTokenNotFound = fmt.Errorf("token not found")
|
||||
ErrInvalidTokenData = fmt.Errorf("invalid token data")
|
||||
)
|
||||
|
||||
// TokenInfo Token information | Token信息
|
||||
type TokenInfo struct {
|
||||
LoginID string `json:"loginId"`
|
||||
Device string `json:"device"`
|
||||
CreateTime int64 `json:"createTime"`
|
||||
ActiveTime int64 `json:"activeTime"` // 最后活跃时间
|
||||
ActiveTime int64 `json:"activeTime"` // Last active time | 最后活跃时间
|
||||
Tag string `json:"tag,omitempty"`
|
||||
}
|
||||
|
||||
// Manager 认证管理器
|
||||
// Manager Authentication manager | 认证管理器
|
||||
type Manager struct {
|
||||
storage adapter.Storage
|
||||
config *config.Config
|
||||
@@ -34,7 +66,7 @@ type Manager struct {
|
||||
oauth2Server *oauth2.OAuth2Server
|
||||
}
|
||||
|
||||
// NewManager 创建管理器
|
||||
// NewManager Creates a new manager | 创建管理器
|
||||
func NewManager(storage adapter.Storage, cfg *config.Config) *Manager {
|
||||
if cfg == nil {
|
||||
cfg = config.DefaultConfig()
|
||||
@@ -44,46 +76,63 @@ func NewManager(storage adapter.Storage, cfg *config.Config) *Manager {
|
||||
storage: storage,
|
||||
config: cfg,
|
||||
generator: token.NewGenerator(cfg),
|
||||
prefix: "satoken:",
|
||||
nonceManager: security.NewNonceManager(storage, 5*time.Minute),
|
||||
prefix: DefaultPrefix,
|
||||
nonceManager: security.NewNonceManager(storage, DefaultNonceTTL),
|
||||
refreshManager: security.NewRefreshTokenManager(storage, cfg),
|
||||
oauth2Server: oauth2.NewOAuth2Server(storage),
|
||||
}
|
||||
}
|
||||
|
||||
// ============ 登录认证 ============
|
||||
// ============ Helper Methods | 辅助方法 ============
|
||||
|
||||
// Login 登录,返回Token
|
||||
// getDevice extracts device type from optional parameter | 从可选参数中提取设备类型
|
||||
func getDevice(device []string) string {
|
||||
if len(device) > 0 && device[0] != "" {
|
||||
return device[0]
|
||||
}
|
||||
return DefaultDevice
|
||||
}
|
||||
|
||||
// getExpiration calculates expiration duration from config | 从配置计算过期时间
|
||||
func (m *Manager) getExpiration() time.Duration {
|
||||
if m.config.Timeout > 0 {
|
||||
return time.Duration(m.config.Timeout) * time.Second
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// assertString safely converts interface to string | 安全地将interface转换为string
|
||||
func assertString(v any) (string, bool) {
|
||||
s, ok := v.(string)
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// ============ Login Authentication | 登录认证 ============
|
||||
|
||||
// Login Performs user login and returns token | 登录,返回Token
|
||||
func (m *Manager) Login(loginID string, device ...string) (string, error) {
|
||||
deviceType := "default"
|
||||
if len(device) > 0 {
|
||||
deviceType = device[0]
|
||||
}
|
||||
deviceType := getDevice(device)
|
||||
|
||||
// 检查是否被封禁
|
||||
// Check if account is disabled | 检查是否被封禁
|
||||
if m.IsDisable(loginID) {
|
||||
return "", fmt.Errorf("account is disabled")
|
||||
return "", ErrAccountDisabled
|
||||
}
|
||||
|
||||
// 如果不允许并发登录,先踢掉旧的
|
||||
// Kick out old session if concurrent login is not allowed | 如果不允许并发登录,先踢掉旧的
|
||||
if !m.config.IsConcurrent {
|
||||
m.kickout(loginID, deviceType)
|
||||
}
|
||||
|
||||
// 生成Token
|
||||
// Generate token | 生成Token
|
||||
tokenValue, err := m.generator.Generate(loginID, deviceType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
// 计算过期时间
|
||||
var expiration time.Duration
|
||||
if m.config.Timeout > 0 {
|
||||
expiration = time.Duration(m.config.Timeout) * time.Second
|
||||
}
|
||||
|
||||
expiration := m.getExpiration()
|
||||
now := time.Now().Unix()
|
||||
// 保存Token信息
|
||||
|
||||
// Save token info | 保存Token信息
|
||||
tokenInfo := &TokenInfo{
|
||||
LoginID: loginID,
|
||||
Device: deviceType,
|
||||
@@ -95,34 +144,27 @@ func (m *Manager) Login(loginID string, device ...string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 保存账号-Token映射
|
||||
// Save account-token mapping | 保存账号-Token映射
|
||||
accountKey := m.getAccountKey(loginID, deviceType)
|
||||
if err := m.storage.Set(accountKey, tokenValue, expiration); err != nil {
|
||||
return "", fmt.Errorf("failed to save account mapping: %w", err)
|
||||
}
|
||||
|
||||
// 创建Session
|
||||
// Create session | 创建Session
|
||||
sess := session.NewSession(loginID, m.storage, m.prefix)
|
||||
sess.Set("loginId", loginID)
|
||||
sess.Set("device", deviceType)
|
||||
sess.Set("loginTime", now)
|
||||
sess.Set(SessionKeyLoginID, loginID)
|
||||
sess.Set(SessionKeyDevice, deviceType)
|
||||
sess.Set(SessionKeyLoginTime, now)
|
||||
|
||||
return tokenValue, nil
|
||||
}
|
||||
|
||||
// LoginByToken 使用指定Token登录(用于token无感刷新)
|
||||
// LoginByToken Login with specified token (for seamless token refresh) | 使用指定Token登录(用于token无感刷新)
|
||||
func (m *Manager) LoginByToken(loginID string, tokenValue string, device ...string) error {
|
||||
deviceType := "default"
|
||||
if len(device) > 0 {
|
||||
deviceType = device[0]
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if m.config.Timeout > 0 {
|
||||
expiration = time.Duration(m.config.Timeout) * time.Second
|
||||
}
|
||||
|
||||
deviceType := getDevice(device)
|
||||
expiration := m.getExpiration()
|
||||
now := time.Now().Unix()
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
LoginID: loginID,
|
||||
Device: deviceType,
|
||||
@@ -138,37 +180,41 @@ func (m *Manager) LoginByToken(loginID string, tokenValue string, device ...stri
|
||||
return m.storage.Set(accountKey, tokenValue, expiration)
|
||||
}
|
||||
|
||||
// Logout 登出
|
||||
// Logout Performs user logout | 登出
|
||||
func (m *Manager) Logout(loginID string, device ...string) error {
|
||||
deviceType := "default"
|
||||
if len(device) > 0 {
|
||||
deviceType = device[0]
|
||||
}
|
||||
|
||||
deviceType := getDevice(device)
|
||||
accountKey := m.getAccountKey(loginID, deviceType)
|
||||
|
||||
tokenValue, err := m.storage.Get(accountKey)
|
||||
if err != nil || tokenValue == nil {
|
||||
return nil // 已经登出
|
||||
return nil // Already logged out | 已经登出
|
||||
}
|
||||
|
||||
// 删除Token
|
||||
tokenKey := m.getTokenKey(tokenValue.(string))
|
||||
// Delete token | 删除Token
|
||||
tokenStr, ok := assertString(tokenValue)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenKey := m.getTokenKey(tokenStr)
|
||||
m.storage.Delete(tokenKey)
|
||||
|
||||
// 删除账号映射
|
||||
// Delete account mapping | 删除账号映射
|
||||
m.storage.Delete(accountKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutByToken 根据Token登出
|
||||
// LogoutByToken Logout by token | 根据Token登出
|
||||
func (m *Manager) LogoutByToken(tokenValue string) error {
|
||||
if tokenValue == "" {
|
||||
return nil
|
||||
}
|
||||
tokenKey := m.getTokenKey(tokenValue)
|
||||
m.storage.Delete(tokenKey)
|
||||
return nil
|
||||
return m.storage.Delete(tokenKey)
|
||||
}
|
||||
|
||||
// kickout 踢人下线
|
||||
// kickout Kick user offline (private) | 踢人下线(私有)
|
||||
func (m *Manager) kickout(loginID string, device string) error {
|
||||
accountKey := m.getAccountKey(loginID, device)
|
||||
tokenValue, err := m.storage.Get(accountKey)
|
||||
@@ -176,34 +222,35 @@ func (m *Manager) kickout(loginID string, device string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenKey := m.getTokenKey(tokenValue.(string))
|
||||
tokenStr, ok := assertString(tokenValue)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenKey := m.getTokenKey(tokenStr)
|
||||
return m.storage.Delete(tokenKey)
|
||||
}
|
||||
|
||||
// Kickout 踢人下线(公开方法)
|
||||
// Kickout Kick user offline (public method) | 踢人下线(公开方法)
|
||||
func (m *Manager) Kickout(loginID string, device ...string) error {
|
||||
deviceType := "default"
|
||||
if len(device) > 0 {
|
||||
deviceType = device[0]
|
||||
}
|
||||
deviceType := getDevice(device)
|
||||
return m.kickout(loginID, deviceType)
|
||||
}
|
||||
|
||||
// ============ Token验证 ============
|
||||
// ============ Token Validation | Token验证 ============
|
||||
|
||||
// IsLogin 检查是否登录
|
||||
// IsLogin Checks if user is logged in | 检查是否登录
|
||||
func (m *Manager) IsLogin(tokenValue string) bool {
|
||||
if tokenValue == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
tokenKey := m.getTokenKey(tokenValue)
|
||||
exists := m.storage.Exists(tokenKey)
|
||||
if !exists {
|
||||
if !m.storage.Exists(tokenKey) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 更新活跃时间并检查活跃超时
|
||||
// Check and update active timeout | 更新活跃时间并检查活跃超时
|
||||
if m.config.ActiveTimeout > 0 {
|
||||
info, _ := m.getTokenInfo(tokenValue)
|
||||
if info != nil {
|
||||
@@ -215,38 +262,41 @@ func (m *Manager) IsLogin(tokenValue string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// ✨ 异步自动续期(提高性能)
|
||||
// Async auto-renew for better performance | 异步自动续期(提高性能)
|
||||
if m.config.AutoRenew && m.config.Timeout > 0 {
|
||||
go func() {
|
||||
expiration := time.Duration(m.config.Timeout) * time.Second
|
||||
|
||||
// 延长Token存储的过期时间
|
||||
m.storage.Expire(tokenKey, expiration)
|
||||
|
||||
// 更新活跃时间
|
||||
info, _ := m.getTokenInfo(tokenValue)
|
||||
if info != nil {
|
||||
info.ActiveTime = time.Now().Unix()
|
||||
m.saveTokenInfo(tokenValue, info, expiration)
|
||||
}
|
||||
}()
|
||||
go m.renewToken(tokenValue, tokenKey)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckLogin 检查登录(未登录抛出错误)
|
||||
// renewToken Renews token expiration asynchronously | 异步续期Token
|
||||
func (m *Manager) renewToken(tokenValue, tokenKey string) {
|
||||
expiration := m.getExpiration()
|
||||
|
||||
// Extend token storage expiration | 延长Token存储的过期时间
|
||||
m.storage.Expire(tokenKey, expiration)
|
||||
|
||||
// Update active time | 更新活跃时间
|
||||
info, _ := m.getTokenInfo(tokenValue)
|
||||
if info != nil {
|
||||
info.ActiveTime = time.Now().Unix()
|
||||
m.saveTokenInfo(tokenValue, info, expiration)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckLogin Checks login status (throws error if not logged in) | 检查登录(未登录抛出错误)
|
||||
func (m *Manager) CheckLogin(tokenValue string) error {
|
||||
if !m.IsLogin(tokenValue) {
|
||||
return fmt.Errorf("not login")
|
||||
return ErrNotLogin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginID 根据Token获取登录ID
|
||||
// GetLoginID Gets login ID from token | 根据Token获取登录ID
|
||||
func (m *Manager) GetLoginID(tokenValue string) (string, error) {
|
||||
if !m.IsLogin(tokenValue) {
|
||||
return "", fmt.Errorf("not login")
|
||||
return "", ErrNotLogin
|
||||
}
|
||||
|
||||
info, err := m.getTokenInfo(tokenValue)
|
||||
@@ -257,7 +307,7 @@ func (m *Manager) GetLoginID(tokenValue string) (string, error) {
|
||||
return info.LoginID, nil
|
||||
}
|
||||
|
||||
// GetLoginIDNotCheck 获取登录ID(不检查Token是否有效)
|
||||
// GetLoginIDNotCheck Gets login ID without checking token validity | 获取登录ID(不检查Token是否有效)
|
||||
func (m *Manager) GetLoginIDNotCheck(tokenValue string) (string, error) {
|
||||
info, err := m.getTokenInfo(tokenValue)
|
||||
if err != nil {
|
||||
@@ -266,50 +316,52 @@ func (m *Manager) GetLoginIDNotCheck(tokenValue string) (string, error) {
|
||||
return info.LoginID, nil
|
||||
}
|
||||
|
||||
// GetTokenValue 根据登录ID获取Token
|
||||
// GetTokenValue Gets token by login ID | 根据登录ID获取Token
|
||||
func (m *Manager) GetTokenValue(loginID string, device ...string) (string, error) {
|
||||
deviceType := "default"
|
||||
if len(device) > 0 {
|
||||
deviceType = device[0]
|
||||
}
|
||||
|
||||
deviceType := getDevice(device)
|
||||
accountKey := m.getAccountKey(loginID, deviceType)
|
||||
|
||||
tokenValue, err := m.storage.Get(accountKey)
|
||||
if err != nil || tokenValue == nil {
|
||||
return "", fmt.Errorf("token not found for login id: %s", loginID)
|
||||
}
|
||||
|
||||
return tokenValue.(string), nil
|
||||
tokenStr, ok := assertString(tokenValue)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid token value type")
|
||||
}
|
||||
|
||||
return tokenStr, nil
|
||||
}
|
||||
|
||||
// GetTokenInfo 获取Token信息
|
||||
// GetTokenInfo Gets token information | 获取Token信息
|
||||
func (m *Manager) GetTokenInfo(tokenValue string) (*TokenInfo, error) {
|
||||
return m.getTokenInfo(tokenValue)
|
||||
}
|
||||
|
||||
// ============ 账号封禁 ============
|
||||
// ============ Account Disable | 账号封禁 ============
|
||||
|
||||
// Disable 封禁账号
|
||||
// Disable Disables an account | 封禁账号
|
||||
func (m *Manager) Disable(loginID string, duration time.Duration) error {
|
||||
key := m.prefix + "disable:" + loginID
|
||||
return m.storage.Set(key, "1", duration)
|
||||
key := m.getDisableKey(loginID)
|
||||
return m.storage.Set(key, DisableValue, duration)
|
||||
}
|
||||
|
||||
// Untie 解封账号
|
||||
// Untie Re-enables a disabled account | 解封账号
|
||||
func (m *Manager) Untie(loginID string) error {
|
||||
key := m.prefix + "disable:" + loginID
|
||||
key := m.getDisableKey(loginID)
|
||||
return m.storage.Delete(key)
|
||||
}
|
||||
|
||||
// IsDisable 检查账号是否被封禁
|
||||
// IsDisable Checks if account is disabled | 检查账号是否被封禁
|
||||
func (m *Manager) IsDisable(loginID string) bool {
|
||||
key := m.prefix + "disable:" + loginID
|
||||
key := m.getDisableKey(loginID)
|
||||
return m.storage.Exists(key)
|
||||
}
|
||||
|
||||
// GetDisableTime 获取账号剩余封禁时间(秒)
|
||||
// GetDisableTime Gets remaining disable time in seconds | 获取账号剩余封禁时间(秒)
|
||||
func (m *Manager) GetDisableTime(loginID string) (int64, error) {
|
||||
key := m.prefix + "disable:" + loginID
|
||||
key := m.getDisableKey(loginID)
|
||||
ttl, err := m.storage.TTL(key)
|
||||
if err != nil {
|
||||
return -2, err
|
||||
@@ -317,9 +369,14 @@ func (m *Manager) GetDisableTime(loginID string) (int64, error) {
|
||||
return int64(ttl.Seconds()), nil
|
||||
}
|
||||
|
||||
// ============ Session管理 ============
|
||||
// getDisableKey Gets disable storage key | 获取禁用存储键
|
||||
func (m *Manager) getDisableKey(loginID string) string {
|
||||
return m.prefix + DisableKeyPrefix + loginID
|
||||
}
|
||||
|
||||
// GetSession 获取Session
|
||||
// ============ Session Management | Session管理 ============
|
||||
|
||||
// GetSession Gets session by login ID | 获取Session
|
||||
func (m *Manager) GetSession(loginID string) (*session.Session, error) {
|
||||
sess, err := session.Load(loginID, m.storage, m.prefix)
|
||||
if err != nil {
|
||||
@@ -328,7 +385,7 @@ func (m *Manager) GetSession(loginID string) (*session.Session, error) {
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// GetSessionByToken 根据Token获取Session
|
||||
// GetSessionByToken Gets session by token | 根据Token获取Session
|
||||
func (m *Manager) GetSessionByToken(tokenValue string) (*session.Session, error) {
|
||||
loginID, err := m.GetLoginID(tokenValue)
|
||||
if err != nil {
|
||||
@@ -337,7 +394,7 @@ func (m *Manager) GetSessionByToken(tokenValue string) (*session.Session, error)
|
||||
return m.GetSession(loginID)
|
||||
}
|
||||
|
||||
// DeleteSession 删除Session
|
||||
// DeleteSession Deletes session | 删除Session
|
||||
func (m *Manager) DeleteSession(loginID string) error {
|
||||
sess, err := m.GetSession(loginID)
|
||||
if err != nil {
|
||||
@@ -346,25 +403,25 @@ func (m *Manager) DeleteSession(loginID string) error {
|
||||
return sess.Destroy()
|
||||
}
|
||||
|
||||
// ============ 权限验证 ============
|
||||
// ============ Permission Validation | 权限验证 ============
|
||||
|
||||
// SetPermissions 设置权限
|
||||
// SetPermissions Sets permissions for user | 设置权限
|
||||
func (m *Manager) SetPermissions(loginID string, permissions []string) error {
|
||||
sess, err := m.GetSession(loginID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sess.Set("permissions", permissions)
|
||||
return sess.Set(SessionKeyPermissions, permissions)
|
||||
}
|
||||
|
||||
// GetPermissions 获取权限列表
|
||||
// GetPermissions Gets permission list | 获取权限列表
|
||||
func (m *Manager) GetPermissions(loginID string) ([]string, error) {
|
||||
sess, err := m.GetSession(loginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
perms, exists := sess.Get("permissions")
|
||||
perms, exists := sess.Get(SessionKeyPermissions)
|
||||
if !exists {
|
||||
return []string{}, nil
|
||||
}
|
||||
@@ -408,27 +465,29 @@ func (m *Manager) HasPermissionsOr(loginID string, permissions []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchPermission 权限匹配(支持通配符)
|
||||
// matchPermission Matches permission with wildcards support | 权限匹配(支持通配符)
|
||||
func (m *Manager) matchPermission(pattern, permission string) bool {
|
||||
if pattern == "*" || pattern == permission {
|
||||
// Exact match or wildcard | 精确匹配或通配符
|
||||
if pattern == PermissionWildcard || pattern == permission {
|
||||
return true
|
||||
}
|
||||
|
||||
// 支持通配符,例如 user:* 匹配 user:add, user:delete等
|
||||
if strings.HasSuffix(pattern, ":*") {
|
||||
prefix := strings.TrimSuffix(pattern, "*")
|
||||
// Pattern like "user:*" matches "user:add", "user:delete", etc. | 支持通配符,例如 user:* 匹配 user:add, user:delete等
|
||||
wildcardSuffix := PermissionSeparator + PermissionWildcard
|
||||
if strings.HasSuffix(pattern, wildcardSuffix) {
|
||||
prefix := strings.TrimSuffix(pattern, PermissionWildcard)
|
||||
return strings.HasPrefix(permission, prefix)
|
||||
}
|
||||
|
||||
// 支持 user:*:view 这样的模式
|
||||
if strings.Contains(pattern, "*") {
|
||||
parts := strings.Split(pattern, ":")
|
||||
permParts := strings.Split(permission, ":")
|
||||
// Pattern like "user:*:view" | 支持 user:*:view 这样的模式
|
||||
if strings.Contains(pattern, PermissionWildcard) {
|
||||
parts := strings.Split(pattern, PermissionSeparator)
|
||||
permParts := strings.Split(permission, PermissionSeparator)
|
||||
if len(parts) != len(permParts) {
|
||||
return false
|
||||
}
|
||||
for i, part := range parts {
|
||||
if part != "*" && part != permParts[i] {
|
||||
if part != PermissionWildcard && part != permParts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -438,25 +497,25 @@ func (m *Manager) matchPermission(pattern, permission string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ============ 角色验证 ============
|
||||
// ============ Role Validation | 角色验证 ============
|
||||
|
||||
// SetRoles 设置角色
|
||||
// SetRoles Sets roles for user | 设置角色
|
||||
func (m *Manager) SetRoles(loginID string, roles []string) error {
|
||||
sess, err := m.GetSession(loginID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sess.Set("roles", roles)
|
||||
return sess.Set(SessionKeyRoles, roles)
|
||||
}
|
||||
|
||||
// GetRoles 获取角色列表
|
||||
// GetRoles Gets role list | 获取角色列表
|
||||
func (m *Manager) GetRoles(loginID string) ([]string, error) {
|
||||
sess, err := m.GetSession(loginID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roles, exists := sess.Get("roles")
|
||||
roles, exists := sess.Get(SessionKeyRoles)
|
||||
if !exists {
|
||||
return []string{}, nil
|
||||
}
|
||||
@@ -499,9 +558,9 @@ func (m *Manager) HasRolesOr(loginID string, roles []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ============ Token标签 ============
|
||||
// ============ Token Tags | Token标签 ============
|
||||
|
||||
// SetTokenTag 设置Token标签
|
||||
// SetTokenTag Sets token tag | 设置Token标签
|
||||
func (m *Manager) SetTokenTag(tokenValue, tag string) error {
|
||||
info, err := m.getTokenInfo(tokenValue)
|
||||
if err != nil {
|
||||
@@ -509,16 +568,12 @@ func (m *Manager) SetTokenTag(tokenValue, tag string) error {
|
||||
}
|
||||
|
||||
info.Tag = tag
|
||||
|
||||
var expiration time.Duration
|
||||
if m.config.Timeout > 0 {
|
||||
expiration = time.Duration(m.config.Timeout) * time.Second
|
||||
}
|
||||
expiration := m.getExpiration()
|
||||
|
||||
return m.saveTokenInfo(tokenValue, info, expiration)
|
||||
}
|
||||
|
||||
// GetTokenTag 获取Token标签
|
||||
// GetTokenTag Gets token tag | 获取Token标签
|
||||
func (m *Manager) GetTokenTag(tokenValue string) (string, error) {
|
||||
info, err := m.getTokenInfo(tokenValue)
|
||||
if err != nil {
|
||||
@@ -527,28 +582,30 @@ func (m *Manager) GetTokenTag(tokenValue string) (string, error) {
|
||||
return info.Tag, nil
|
||||
}
|
||||
|
||||
// ============ 会话查询 ============
|
||||
// ============ Session Query | 会话查询 ============
|
||||
|
||||
// GetTokenValueListByLoginID 获取指定账号的所有Token
|
||||
// GetTokenValueListByLoginID Gets all tokens for specified account | 获取指定账号的所有Token
|
||||
func (m *Manager) GetTokenValueListByLoginID(loginID string) ([]string, error) {
|
||||
pattern := m.prefix + "account:" + loginID + ":*"
|
||||
pattern := m.prefix + AccountKeyPrefix + loginID + ":*"
|
||||
keys, err := m.storage.Keys(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokens := make([]string, 0)
|
||||
tokens := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
value, err := m.storage.Get(key)
|
||||
if err == nil && value != nil {
|
||||
tokens = append(tokens, value.(string))
|
||||
if tokenStr, ok := assertString(value); ok {
|
||||
tokens = append(tokens, tokenStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// GetSessionCountByLoginID 获取指定账号的Session数量
|
||||
// GetSessionCountByLoginID Gets session count for specified account | 获取指定账号的Session数量
|
||||
func (m *Manager) GetSessionCountByLoginID(loginID string) (int, error) {
|
||||
tokens, err := m.GetTokenValueListByLoginID(loginID)
|
||||
if err != nil {
|
||||
@@ -557,19 +614,19 @@ func (m *Manager) GetSessionCountByLoginID(loginID string) (int, error) {
|
||||
return len(tokens), nil
|
||||
}
|
||||
|
||||
// ============ 辅助方法 ============
|
||||
// ============ Internal Helper Methods | 内部辅助方法 ============
|
||||
|
||||
// getTokenKey 获取Token存储键
|
||||
// getTokenKey Gets token storage key | 获取Token存储键
|
||||
func (m *Manager) getTokenKey(tokenValue string) string {
|
||||
return m.prefix + "token:" + tokenValue
|
||||
return m.prefix + TokenKeyPrefix + tokenValue
|
||||
}
|
||||
|
||||
// getAccountKey 获取账号存储键
|
||||
// getAccountKey Gets account storage key | 获取账号存储键
|
||||
func (m *Manager) getAccountKey(loginID, device string) string {
|
||||
return m.prefix + "account:" + loginID + ":" + device
|
||||
return m.prefix + AccountKeyPrefix + loginID + PermissionSeparator + device
|
||||
}
|
||||
|
||||
// saveTokenInfo 保存Token信息
|
||||
// saveTokenInfo Saves token information | 保存Token信息
|
||||
func (m *Manager) saveTokenInfo(tokenValue string, info *TokenInfo, expiration time.Duration) error {
|
||||
data, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
@@ -580,28 +637,33 @@ func (m *Manager) saveTokenInfo(tokenValue string, info *TokenInfo, expiration t
|
||||
return m.storage.Set(tokenKey, string(data), expiration)
|
||||
}
|
||||
|
||||
// getTokenInfo 获取Token信息
|
||||
// getTokenInfo Gets token information | 获取Token信息
|
||||
func (m *Manager) getTokenInfo(tokenValue string) (*TokenInfo, error) {
|
||||
tokenKey := m.getTokenKey(tokenValue)
|
||||
data, err := m.storage.Get(tokenKey)
|
||||
if err != nil || data == nil {
|
||||
return nil, fmt.Errorf("token not found")
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
|
||||
dataStr, ok := assertString(data)
|
||||
if !ok {
|
||||
return nil, ErrInvalidTokenData
|
||||
}
|
||||
|
||||
var info TokenInfo
|
||||
if err := json.Unmarshal([]byte(data.(string)), &info); err != nil {
|
||||
return nil, fmt.Errorf("invalid token data: %w", err)
|
||||
if err := json.Unmarshal([]byte(dataStr), &info); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidTokenData, err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// toStringSlice 将interface{}转换为[]string
|
||||
func (m *Manager) toStringSlice(v interface{}) []string {
|
||||
// toStringSlice Converts any to []string | 将any转换为[]string
|
||||
func (m *Manager) toStringSlice(v any) []string {
|
||||
switch val := v.(type) {
|
||||
case []string:
|
||||
return val
|
||||
case []interface{}:
|
||||
case []any:
|
||||
result := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
if str, ok := item.(string); ok {
|
||||
@@ -614,36 +676,46 @@ func (m *Manager) toStringSlice(v interface{}) []string {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig 获取配置
|
||||
// ============ Public Getters | 公共获取器 ============
|
||||
|
||||
// GetConfig Gets configuration | 获取配置
|
||||
func (m *Manager) GetConfig() *config.Config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
// GetStorage 获取存储
|
||||
// GetStorage Gets storage | 获取存储
|
||||
func (m *Manager) GetStorage() adapter.Storage {
|
||||
return m.storage
|
||||
}
|
||||
|
||||
// ============ Security Features | 安全特性 ============
|
||||
|
||||
// GenerateNonce Generates a one-time nonce | 生成一次性随机数
|
||||
func (m *Manager) GenerateNonce() (string, error) {
|
||||
return m.nonceManager.Generate()
|
||||
}
|
||||
|
||||
// VerifyNonce Verifies a nonce | 验证随机数
|
||||
func (m *Manager) VerifyNonce(nonce string) bool {
|
||||
return m.nonceManager.Verify(nonce)
|
||||
}
|
||||
|
||||
// LoginWithRefreshToken Logs in with refresh token | 使用刷新令牌登录
|
||||
func (m *Manager) LoginWithRefreshToken(loginID, device string) (*security.RefreshTokenInfo, error) {
|
||||
return m.refreshManager.GenerateTokenPair(loginID, device)
|
||||
}
|
||||
|
||||
// RefreshAccessToken Refreshes access token | 刷新访问令牌
|
||||
func (m *Manager) RefreshAccessToken(refreshToken string) (*security.RefreshTokenInfo, error) {
|
||||
return m.refreshManager.RefreshAccessToken(refreshToken)
|
||||
}
|
||||
|
||||
// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌
|
||||
func (m *Manager) RevokeRefreshToken(refreshToken string) error {
|
||||
return m.refreshManager.RevokeRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例
|
||||
func (m *Manager) GetOAuth2Server() *oauth2.OAuth2Server {
|
||||
return m.oauth2Server
|
||||
}
|
||||
|
||||
+156
-58
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/click33/sa-token-go/core/adapter"
|
||||
@@ -20,11 +21,42 @@ import (
|
||||
// 5. RefreshAccessToken() - Use refresh token to get new token | 用刷新令牌获取新令牌
|
||||
//
|
||||
// Usage | 用法:
|
||||
// server := core.NewOAuth2Server(storage)
|
||||
// server.RegisterClient(&core.OAuth2Client{...})
|
||||
// server := oauth2.NewOAuth2Server(storage)
|
||||
// server.RegisterClient(&oauth2.Client{...})
|
||||
// authCode, _ := server.GenerateAuthorizationCode(...)
|
||||
// token, _ := server.ExchangeCodeForToken(...)
|
||||
|
||||
// Constants for OAuth2 | OAuth2常量
|
||||
const (
|
||||
DefaultCodeExpiration = 10 * time.Minute // Authorization code expiration | 授权码过期时间
|
||||
DefaultTokenExpiration = 2 * time.Hour // Access token expiration | 访问令牌过期时间
|
||||
DefaultRefreshTTL = 30 * 24 * time.Hour // Refresh token expiration | 刷新令牌过期时间
|
||||
|
||||
CodeLength = 32 // Authorization code byte length | 授权码字节长度
|
||||
AccessTokenLength = 32 // Access token byte length | 访问令牌字节长度
|
||||
RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度
|
||||
|
||||
CodeKeyPrefix = "satoken:oauth2:code:" // Code storage key prefix | 授权码存储键前缀
|
||||
TokenKeyPrefix = "satoken:oauth2:token:" // Token storage key prefix | 令牌存储键前缀
|
||||
RefreshKeyPrefix = "satoken:oauth2:refresh:" // Refresh storage key prefix | 刷新令牌存储键前缀
|
||||
|
||||
TokenTypeBearer = "Bearer" // Token type | 令牌类型
|
||||
)
|
||||
|
||||
// Error variables | 错误变量
|
||||
var (
|
||||
ErrClientNotFound = fmt.Errorf("client not found")
|
||||
ErrInvalidRedirectURI = fmt.Errorf("invalid redirect_uri")
|
||||
ErrInvalidClientCredentials = fmt.Errorf("invalid client credentials")
|
||||
ErrInvalidAuthCode = fmt.Errorf("invalid authorization code")
|
||||
ErrAuthCodeUsed = fmt.Errorf("authorization code already used")
|
||||
ErrAuthCodeExpired = fmt.Errorf("authorization code expired")
|
||||
ErrClientMismatch = fmt.Errorf("client mismatch")
|
||||
ErrRedirectURIMismatch = fmt.Errorf("redirect_uri mismatch")
|
||||
ErrInvalidAccessToken = fmt.Errorf("invalid access token")
|
||||
ErrInvalidTokenData = fmt.Errorf("invalid token data")
|
||||
)
|
||||
|
||||
// GrantType OAuth2 grant type | OAuth2授权类型
|
||||
type GrantType string
|
||||
|
||||
@@ -71,55 +103,74 @@ type AccessToken struct {
|
||||
type OAuth2Server struct {
|
||||
storage adapter.Storage
|
||||
clients map[string]*Client
|
||||
clientsMu sync.RWMutex // Clients map lock | 客户端映射锁
|
||||
codeExpiration time.Duration // Authorization code expiration (10min) | 授权码过期时间(10分钟)
|
||||
tokenExpiration time.Duration // Access token expiration (2h) | 访问令牌过期时间(2小时)
|
||||
}
|
||||
|
||||
// NewOAuth2Server creates a new OAuth2 server | 创建新的OAuth2服务器
|
||||
// NewOAuth2Server Creates a new OAuth2 server | 创建新的OAuth2服务器
|
||||
func NewOAuth2Server(storage adapter.Storage) *OAuth2Server {
|
||||
return &OAuth2Server{
|
||||
storage: storage,
|
||||
clients: make(map[string]*Client),
|
||||
codeExpiration: 10 * time.Minute, // Authorization code expires in 10 minutes | 授权码10分钟过期
|
||||
tokenExpiration: 2 * time.Hour, // Access token expires in 2 hours | 访问令牌2小时过期
|
||||
codeExpiration: DefaultCodeExpiration,
|
||||
tokenExpiration: DefaultTokenExpiration,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClient registers an OAuth2 client | 注册OAuth2客户端
|
||||
func (s *OAuth2Server) RegisterClient(client *Client) {
|
||||
// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端
|
||||
func (s *OAuth2Server) RegisterClient(client *Client) error {
|
||||
if client == nil || client.ClientID == "" {
|
||||
return fmt.Errorf("invalid client: clientID is required")
|
||||
}
|
||||
|
||||
s.clientsMu.Lock()
|
||||
defer s.clientsMu.Unlock()
|
||||
|
||||
s.clients[client.ClientID] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClient gets client by ID | 根据ID获取客户端
|
||||
// UnregisterClient Unregisters an OAuth2 client | 注销OAuth2客户端
|
||||
func (s *OAuth2Server) UnregisterClient(clientID string) {
|
||||
s.clientsMu.Lock()
|
||||
defer s.clientsMu.Unlock()
|
||||
|
||||
delete(s.clients, clientID)
|
||||
}
|
||||
|
||||
// GetClient Gets client by ID | 根据ID获取客户端
|
||||
func (s *OAuth2Server) GetClient(clientID string) (*Client, error) {
|
||||
s.clientsMu.RLock()
|
||||
defer s.clientsMu.RUnlock()
|
||||
|
||||
client, exists := s.clients[clientID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("client not found")
|
||||
return nil, ErrClientNotFound
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GenerateAuthorizationCode generates authorization code | 生成授权码
|
||||
// GenerateAuthorizationCode Generates authorization code | 生成授权码
|
||||
func (s *OAuth2Server) GenerateAuthorizationCode(clientID, redirectURI, userID string, scopes []string) (*AuthorizationCode, error) {
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("userID cannot be empty")
|
||||
}
|
||||
|
||||
client, err := s.GetClient(clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validRedirect := false
|
||||
for _, uri := range client.RedirectURIs {
|
||||
if uri == redirectURI {
|
||||
validRedirect = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !validRedirect {
|
||||
return nil, fmt.Errorf("invalid redirect_uri")
|
||||
// Validate redirect URI | 验证回调URI
|
||||
if !s.isValidRedirectURI(client, redirectURI) {
|
||||
return nil, ErrInvalidRedirectURI
|
||||
}
|
||||
|
||||
codeBytes := make([]byte, 32)
|
||||
// Generate code | 生成授权码
|
||||
codeBytes := make([]byte, CodeLength)
|
||||
if _, err := rand.Read(codeBytes); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to generate authorization code: %w", err)
|
||||
}
|
||||
code := hex.EncodeToString(codeBytes)
|
||||
|
||||
@@ -134,29 +185,41 @@ func (s *OAuth2Server) GenerateAuthorizationCode(clientID, redirectURI, userID s
|
||||
Used: false,
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("satoken:oauth2:code:%s", code)
|
||||
key := s.getCodeKey(code)
|
||||
if err := s.storage.Set(key, authCode, s.codeExpiration); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to store authorization code: %w", err)
|
||||
}
|
||||
|
||||
return authCode, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForToken exchanges authorization code for access token | 用授权码换取访问令牌
|
||||
// isValidRedirectURI Checks if redirect URI is valid for client | 检查回调URI是否有效
|
||||
func (s *OAuth2Server) isValidRedirectURI(client *Client, redirectURI string) bool {
|
||||
for _, uri := range client.RedirectURIs {
|
||||
if uri == redirectURI {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExchangeCodeForToken Exchanges authorization code for access token | 用授权码换取访问令牌
|
||||
func (s *OAuth2Server) ExchangeCodeForToken(code, clientID, clientSecret, redirectURI string) (*AccessToken, error) {
|
||||
// Verify client credentials | 验证客户端凭证
|
||||
client, err := s.GetClient(clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if client.ClientSecret != clientSecret {
|
||||
return nil, fmt.Errorf("invalid client credentials")
|
||||
return nil, ErrInvalidClientCredentials
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("satoken:oauth2:code:%s", code)
|
||||
// Get authorization code | 获取授权码
|
||||
key := s.getCodeKey(code)
|
||||
data, err := s.storage.Get(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid authorization code")
|
||||
if err != nil || data == nil {
|
||||
return nil, ErrInvalidAuthCode
|
||||
}
|
||||
|
||||
authCode, ok := data.(*AuthorizationCode)
|
||||
@@ -164,45 +227,50 @@ func (s *OAuth2Server) ExchangeCodeForToken(code, clientID, clientSecret, redire
|
||||
return nil, fmt.Errorf("invalid code data")
|
||||
}
|
||||
|
||||
// Validate authorization code | 验证授权码
|
||||
if authCode.Used {
|
||||
return nil, fmt.Errorf("authorization code already used")
|
||||
return nil, ErrAuthCodeUsed
|
||||
}
|
||||
|
||||
if authCode.ClientID != clientID {
|
||||
return nil, fmt.Errorf("client mismatch")
|
||||
return nil, ErrClientMismatch
|
||||
}
|
||||
|
||||
if authCode.RedirectURI != redirectURI {
|
||||
return nil, fmt.Errorf("redirect_uri mismatch")
|
||||
return nil, ErrRedirectURIMismatch
|
||||
}
|
||||
|
||||
if time.Now().Unix() > authCode.CreateTime+authCode.ExpiresIn {
|
||||
s.storage.Delete(key)
|
||||
return nil, fmt.Errorf("authorization code expired")
|
||||
return nil, ErrAuthCodeExpired
|
||||
}
|
||||
|
||||
// Mark code as used | 标记为已使用
|
||||
authCode.Used = true
|
||||
s.storage.Set(key, authCode, time.Minute)
|
||||
|
||||
return s.generateAccessToken(authCode.UserID, authCode.ClientID, authCode.Scopes)
|
||||
}
|
||||
|
||||
// generateAccessToken Generates access token and refresh token | 生成访问令牌和刷新令牌
|
||||
func (s *OAuth2Server) generateAccessToken(userID, clientID string, scopes []string) (*AccessToken, error) {
|
||||
tokenBytes := make([]byte, 32)
|
||||
// Generate access token | 生成访问令牌
|
||||
tokenBytes := make([]byte, AccessTokenLength)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
accessToken := hex.EncodeToString(tokenBytes)
|
||||
|
||||
refreshBytes := make([]byte, 32)
|
||||
// Generate refresh token | 生成刷新令牌
|
||||
refreshBytes := make([]byte, RefreshTokenLength)
|
||||
if _, err := rand.Read(refreshBytes); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
refreshToken := hex.EncodeToString(refreshBytes)
|
||||
|
||||
token := &AccessToken{
|
||||
Token: accessToken,
|
||||
TokenType: "Bearer",
|
||||
TokenType: TokenTypeBearer,
|
||||
ExpiresIn: int64(s.tokenExpiration.Seconds()),
|
||||
RefreshToken: refreshToken,
|
||||
Scopes: scopes,
|
||||
@@ -210,50 +278,58 @@ func (s *OAuth2Server) generateAccessToken(userID, clientID string, scopes []str
|
||||
ClientID: clientID,
|
||||
}
|
||||
|
||||
tokenKey := fmt.Sprintf("satoken:oauth2:token:%s", accessToken)
|
||||
refreshKey := fmt.Sprintf("satoken:oauth2:refresh:%s", refreshToken)
|
||||
tokenKey := s.getTokenKey(accessToken)
|
||||
refreshKey := s.getRefreshKey(refreshToken)
|
||||
|
||||
// Store access token | 存储访问令牌
|
||||
if err := s.storage.Set(tokenKey, token, s.tokenExpiration); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to store access token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.storage.Set(refreshKey, token, 30*24*time.Hour); err != nil {
|
||||
return nil, err
|
||||
// Store refresh token | 存储刷新令牌
|
||||
if err := s.storage.Set(refreshKey, token, DefaultRefreshTTL); err != nil {
|
||||
return nil, fmt.Errorf("failed to store refresh token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ValidateAccessToken validates access token | 验证访问令牌
|
||||
// ValidateAccessToken Validates access token | 验证访问令牌
|
||||
func (s *OAuth2Server) ValidateAccessToken(tokenString string) (*AccessToken, error) {
|
||||
key := fmt.Sprintf("satoken:oauth2:token:%s", tokenString)
|
||||
if tokenString == "" {
|
||||
return nil, ErrInvalidAccessToken
|
||||
}
|
||||
|
||||
key := s.getTokenKey(tokenString)
|
||||
data, err := s.storage.Get(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid access token")
|
||||
if err != nil || data == nil {
|
||||
return nil, ErrInvalidAccessToken
|
||||
}
|
||||
|
||||
token, ok := data.(*AccessToken)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token data")
|
||||
return nil, ErrInvalidTokenData
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌
|
||||
// RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌
|
||||
func (s *OAuth2Server) RefreshAccessToken(refreshToken, clientID, clientSecret string) (*AccessToken, error) {
|
||||
// Verify client credentials | 验证客户端凭证
|
||||
client, err := s.GetClient(clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if client.ClientSecret != clientSecret {
|
||||
return nil, fmt.Errorf("invalid client credentials")
|
||||
return nil, ErrInvalidClientCredentials
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("satoken:oauth2:refresh:%s", refreshToken)
|
||||
// Get refresh token | 获取刷新令牌
|
||||
key := s.getRefreshKey(refreshToken)
|
||||
data, err := s.storage.Get(key)
|
||||
if err != nil {
|
||||
if err != nil || data == nil {
|
||||
return nil, fmt.Errorf("invalid refresh token")
|
||||
}
|
||||
|
||||
@@ -263,28 +339,50 @@ func (s *OAuth2Server) RefreshAccessToken(refreshToken, clientID, clientSecret s
|
||||
}
|
||||
|
||||
if oldToken.ClientID != clientID {
|
||||
return nil, fmt.Errorf("client mismatch")
|
||||
return nil, ErrClientMismatch
|
||||
}
|
||||
|
||||
oldTokenKey := fmt.Sprintf("satoken:oauth2:token:%s", oldToken.Token)
|
||||
// Delete old access token | 删除旧的访问令牌
|
||||
oldTokenKey := s.getTokenKey(oldToken.Token)
|
||||
s.storage.Delete(oldTokenKey)
|
||||
|
||||
return s.generateAccessToken(oldToken.UserID, oldToken.ClientID, oldToken.Scopes)
|
||||
}
|
||||
|
||||
// RevokeToken revokes access token and its refresh token | 撤销访问令牌及其刷新令牌
|
||||
// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌
|
||||
func (s *OAuth2Server) RevokeToken(tokenString string) error {
|
||||
key := fmt.Sprintf("satoken:oauth2:token:%s", tokenString)
|
||||
if tokenString == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := s.getTokenKey(tokenString)
|
||||
data, err := s.storage.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, ok := data.(*AccessToken)
|
||||
if ok && token.RefreshToken != "" {
|
||||
refreshKey := fmt.Sprintf("satoken:oauth2:refresh:%s", token.RefreshToken)
|
||||
// Revoke refresh token if exists | 如果存在则撤销刷新令牌
|
||||
if token, ok := data.(*AccessToken); ok && token.RefreshToken != "" {
|
||||
refreshKey := s.getRefreshKey(token.RefreshToken)
|
||||
s.storage.Delete(refreshKey)
|
||||
}
|
||||
|
||||
return s.storage.Delete(key)
|
||||
}
|
||||
|
||||
// ============ Helper Methods | 辅助方法 ============
|
||||
|
||||
// getCodeKey Gets storage key for authorization code | 获取授权码的存储键
|
||||
func (s *OAuth2Server) getCodeKey(code string) string {
|
||||
return CodeKeyPrefix + code
|
||||
}
|
||||
|
||||
// getTokenKey Gets storage key for access token | 获取访问令牌的存储键
|
||||
func (s *OAuth2Server) getTokenKey(token string) string {
|
||||
return TokenKeyPrefix + token
|
||||
}
|
||||
|
||||
// getRefreshKey Gets storage key for refresh token | 获取刷新令牌的存储键
|
||||
func (s *OAuth2Server) getRefreshKey(refreshToken string) string {
|
||||
return RefreshKeyPrefix + refreshToken
|
||||
}
|
||||
|
||||
+39
-14
@@ -19,6 +19,7 @@ import (
|
||||
// Version Sa-Token-Go version | Sa-Token-Go版本
|
||||
const Version = "0.1.0"
|
||||
|
||||
// ============ Exported Types | 导出的类型 ============
|
||||
// Export main types and functions for external use | 导出主要类型和函数,方便外部使用
|
||||
|
||||
// Configuration related types | 配置相关类型
|
||||
@@ -96,59 +97,81 @@ const (
|
||||
GrantTypePassword = oauth2.GrantTypePassword
|
||||
)
|
||||
|
||||
// Utility functions | 工具函数
|
||||
// ============ Utility Functions | 工具函数 ============
|
||||
|
||||
var (
|
||||
RandomString = utils.RandomString
|
||||
IsEmpty = utils.IsEmpty
|
||||
IsNotEmpty = utils.IsNotEmpty
|
||||
DefaultString = utils.DefaultString
|
||||
// String utilities | 字符串工具
|
||||
RandomString = utils.RandomString
|
||||
RandomNumericString = utils.RandomNumericString
|
||||
RandomAlphanumeric = utils.RandomAlphanumeric
|
||||
IsEmpty = utils.IsEmpty
|
||||
IsNotEmpty = utils.IsNotEmpty
|
||||
DefaultString = utils.DefaultString
|
||||
|
||||
// Slice utilities | 切片工具
|
||||
ContainsString = utils.ContainsString
|
||||
RemoveString = utils.RemoveString
|
||||
UniqueStrings = utils.UniqueStrings
|
||||
MergeStrings = utils.MergeStrings
|
||||
MatchPattern = utils.MatchPattern
|
||||
FilterStrings = utils.FilterStrings
|
||||
MapStrings = utils.MapStrings
|
||||
|
||||
// Pattern matching | 模式匹配
|
||||
MatchPattern = utils.MatchPattern
|
||||
|
||||
// Duration utilities | 时长工具
|
||||
FormatDuration = utils.FormatDuration
|
||||
ParseDuration = utils.ParseDuration
|
||||
|
||||
// Hash & Encoding | 哈希和编码
|
||||
SHA256Hash = utils.SHA256Hash
|
||||
Base64Encode = utils.Base64Encode
|
||||
Base64Decode = utils.Base64Decode
|
||||
)
|
||||
|
||||
// DefaultConfig returns default configuration | 返回默认配置
|
||||
// ============ Factory Functions | 工厂函数 ============
|
||||
|
||||
// DefaultConfig Returns default configuration | 返回默认配置
|
||||
func DefaultConfig() *Config {
|
||||
return config.DefaultConfig()
|
||||
}
|
||||
|
||||
// NewManager creates a new authentication manager | 创建新的认证管理器
|
||||
// NewManager Creates a new authentication manager | 创建新的认证管理器
|
||||
func NewManager(storage Storage, cfg *Config) *Manager {
|
||||
return manager.NewManager(storage, cfg)
|
||||
}
|
||||
|
||||
// NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文
|
||||
// NewContext Creates a new Sa-Token context | 创建新的Sa-Token上下文
|
||||
func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext {
|
||||
return context.NewContext(ctx, mgr)
|
||||
}
|
||||
|
||||
// NewSession creates a new session | 创建新的Session
|
||||
// NewSession Creates a new session | 创建新的Session
|
||||
func NewSession(id string, storage Storage, prefix string) *Session {
|
||||
return session.NewSession(id, storage, prefix)
|
||||
}
|
||||
|
||||
// LoadSession loads an existing session | 加载已存在的Session
|
||||
// LoadSession Loads an existing session | 加载已存在的Session
|
||||
func LoadSession(id string, storage Storage, prefix string) (*Session, error) {
|
||||
return session.Load(id, storage, prefix)
|
||||
}
|
||||
|
||||
// NewTokenGenerator creates a new token generator | 创建新的Token生成器
|
||||
// NewTokenGenerator Creates a new token generator | 创建新的Token生成器
|
||||
func NewTokenGenerator(cfg *Config) *TokenGenerator {
|
||||
return token.NewGenerator(cfg)
|
||||
}
|
||||
|
||||
// NewEventManager creates a new event manager | 创建新的事件管理器
|
||||
// NewEventManager Creates a new event manager | 创建新的事件管理器
|
||||
func NewEventManager() *EventManager {
|
||||
return listener.NewManager()
|
||||
}
|
||||
|
||||
// NewBuilder creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置)
|
||||
// NewBuilder Creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置)
|
||||
func NewBuilder() *Builder {
|
||||
return builder.NewBuilder()
|
||||
}
|
||||
|
||||
// NewNonceManager Creates a new nonce manager | 创建新的Nonce管理器
|
||||
func NewNonceManager(storage Storage, ttl ...int64) *NonceManager {
|
||||
var duration time.Duration
|
||||
if len(ttl) > 0 && ttl[0] > 0 {
|
||||
@@ -157,10 +180,12 @@ func NewNonceManager(storage Storage, ttl ...int64) *NonceManager {
|
||||
return security.NewNonceManager(storage, duration)
|
||||
}
|
||||
|
||||
// NewRefreshTokenManager Creates a new refresh token manager | 创建新的刷新令牌管理器
|
||||
func NewRefreshTokenManager(storage Storage, cfg *Config) *RefreshTokenManager {
|
||||
return security.NewRefreshTokenManager(storage, cfg)
|
||||
}
|
||||
|
||||
// NewOAuth2Server Creates a new OAuth2 server | 创建新的OAuth2服务器
|
||||
func NewOAuth2Server(storage Storage) *OAuth2Server {
|
||||
return oauth2.NewOAuth2Server(storage)
|
||||
}
|
||||
|
||||
+40
-13
@@ -23,6 +23,18 @@ import (
|
||||
// valid := manager.VerifyNonce(nonce) // true
|
||||
// valid = manager.VerifyNonce(nonce) // false (replay prevented)
|
||||
|
||||
// Constants for nonce | Nonce常量
|
||||
const (
|
||||
DefaultNonceTTL = 5 * time.Minute // Default nonce expiration | 默认nonce过期时间
|
||||
NonceLength = 32 // Nonce byte length | Nonce字节长度
|
||||
NonceKeyPrefix = "satoken:nonce:" // Storage key prefix | 存储键前缀
|
||||
)
|
||||
|
||||
// Error variables | 错误变量
|
||||
var (
|
||||
ErrInvalidNonce = fmt.Errorf("invalid or expired nonce")
|
||||
)
|
||||
|
||||
// NonceManager Nonce manager for anti-replay attacks | Nonce管理器,用于防重放攻击
|
||||
type NonceManager struct {
|
||||
storage adapter.Storage
|
||||
@@ -30,11 +42,11 @@ type NonceManager struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器
|
||||
// NewNonceManager Creates a new nonce manager | 创建新的Nonce管理器
|
||||
// ttl: time to live, default 5 minutes | 过期时间,默认5分钟
|
||||
func NewNonceManager(storage adapter.Storage, ttl time.Duration) *NonceManager {
|
||||
if ttl == 0 {
|
||||
ttl = 5 * time.Minute
|
||||
ttl = DefaultNonceTTL
|
||||
}
|
||||
return &NonceManager{
|
||||
storage: storage,
|
||||
@@ -42,31 +54,31 @@ func NewNonceManager(storage adapter.Storage, ttl time.Duration) *NonceManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Generate generates a new nonce and stores it | 生成新的nonce并存储
|
||||
// Generate Generates a new nonce and stores it | 生成新的nonce并存储
|
||||
// Returns 64-char hex string | 返回64字符的十六进制字符串
|
||||
func (nm *NonceManager) Generate() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
bytes := make([]byte, NonceLength)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
nonce := hex.EncodeToString(bytes)
|
||||
|
||||
key := fmt.Sprintf("satoken:nonce:%s", nonce)
|
||||
key := nm.getNonceKey(nonce)
|
||||
if err := nm.storage.Set(key, time.Now().Unix(), nm.ttl); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to store nonce: %w", err)
|
||||
}
|
||||
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
// Verify verifies nonce and consumes it (one-time use) | 验证nonce并消费它(一次性使用)
|
||||
// Verify Verifies nonce and consumes it (one-time use) | 验证nonce并消费它(一次性使用)
|
||||
// Returns false if nonce doesn't exist or already used | 如果nonce不存在或已使用则返回false
|
||||
func (nm *NonceManager) Verify(nonce string) bool {
|
||||
if nonce == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("satoken:nonce:%s", nonce)
|
||||
key := nm.getNonceKey(nonce)
|
||||
|
||||
nm.mu.Lock()
|
||||
defer nm.mu.Unlock()
|
||||
@@ -79,14 +91,29 @@ func (nm *NonceManager) Verify(nonce string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// VerifyAndConsume verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误
|
||||
// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误
|
||||
func (nm *NonceManager) VerifyAndConsume(nonce string) error {
|
||||
if !nm.Verify(nonce) {
|
||||
return fmt.Errorf("invalid or expired nonce")
|
||||
return ErrInvalidNonce
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clean cleans expired nonces (handled by storage TTL) | 清理过期的nonce(由存储的TTL处理)
|
||||
func (nm *NonceManager) Clean() {
|
||||
// IsValid Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费)
|
||||
func (nm *NonceManager) IsValid(nonce string) bool {
|
||||
if nonce == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
key := nm.getNonceKey(nonce)
|
||||
|
||||
nm.mu.RLock()
|
||||
defer nm.mu.RUnlock()
|
||||
|
||||
return nm.storage.Exists(key)
|
||||
}
|
||||
|
||||
// getNonceKey Gets storage key for nonce | 获取nonce的存储键
|
||||
func (nm *NonceManager) getNonceKey(nonce string) string {
|
||||
return NonceKeyPrefix + nonce
|
||||
}
|
||||
|
||||
@@ -25,6 +25,21 @@ import (
|
||||
// // ... access token expires ...
|
||||
// newInfo, _ := manager.RefreshAccessToken(tokenInfo.RefreshToken)
|
||||
|
||||
// Constants for refresh token | 刷新令牌常量
|
||||
const (
|
||||
DefaultRefreshTTL = 30 * 24 * time.Hour // 30 days | 30天
|
||||
DefaultAccessTTL = 2 * time.Hour // 2 hours | 2小时
|
||||
RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度
|
||||
RefreshKeyPrefix = "satoken:refresh:" // Storage key prefix | 存储键前缀
|
||||
)
|
||||
|
||||
// Error variables | 错误变量
|
||||
var (
|
||||
ErrInvalidRefreshToken = fmt.Errorf("invalid refresh token")
|
||||
ErrRefreshTokenExpired = fmt.Errorf("refresh token expired")
|
||||
ErrInvalidRefreshData = fmt.Errorf("invalid refresh token data")
|
||||
)
|
||||
|
||||
// RefreshTokenInfo refresh token information | 刷新令牌信息
|
||||
type RefreshTokenInfo struct {
|
||||
RefreshToken string // Refresh token (long-lived) | 刷新令牌(长期有效)
|
||||
@@ -35,7 +50,7 @@ type RefreshTokenInfo struct {
|
||||
ExpireTime int64 // Expiration timestamp | 过期时间戳
|
||||
}
|
||||
|
||||
// RefreshTokenManager refresh token manager | 刷新令牌管理器
|
||||
// RefreshTokenManager Refresh token manager | 刷新令牌管理器
|
||||
type RefreshTokenManager struct {
|
||||
storage adapter.Storage
|
||||
tokenGen *token.Generator
|
||||
@@ -43,34 +58,39 @@ type RefreshTokenManager struct {
|
||||
accessTTL time.Duration // Access token TTL (configurable) | 访问令牌有效期(可配置)
|
||||
}
|
||||
|
||||
// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器
|
||||
// NewRefreshTokenManager Creates a new refresh token manager | 创建新的刷新令牌管理器
|
||||
// cfg: configuration, uses Timeout for access token TTL | 配置,使用Timeout作为访问令牌有效期
|
||||
func NewRefreshTokenManager(storage adapter.Storage, cfg *config.Config) *RefreshTokenManager {
|
||||
refreshTTL := 30 * 24 * time.Hour // 30 days | 30天
|
||||
accessTTL := time.Duration(cfg.Timeout) * time.Second
|
||||
|
||||
if accessTTL == 0 {
|
||||
accessTTL = 2 * time.Hour // Default 2 hours | 默认2小时
|
||||
accessTTL = DefaultAccessTTL
|
||||
}
|
||||
|
||||
return &RefreshTokenManager{
|
||||
storage: storage,
|
||||
tokenGen: token.NewGenerator(cfg),
|
||||
refreshTTL: refreshTTL,
|
||||
refreshTTL: DefaultRefreshTTL,
|
||||
accessTTL: accessTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTokenPair generates access token and refresh token pair | 生成访问令牌和刷新令牌对
|
||||
// GenerateTokenPair Generates access token and refresh token pair | 生成访问令牌和刷新令牌对
|
||||
func (rtm *RefreshTokenManager) GenerateTokenPair(loginID, device string) (*RefreshTokenInfo, error) {
|
||||
accessToken, err := rtm.tokenGen.Generate(loginID, device)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if loginID == "" {
|
||||
return nil, fmt.Errorf("loginID cannot be empty")
|
||||
}
|
||||
|
||||
refreshTokenBytes := make([]byte, 32)
|
||||
// Generate access token | 生成访问令牌
|
||||
accessToken, err := rtm.tokenGen.Generate(loginID, device)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
// Generate refresh token | 生成刷新令牌
|
||||
refreshTokenBytes := make([]byte, RefreshTokenLength)
|
||||
if _, err := rand.Read(refreshTokenBytes); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
refreshToken := hex.EncodeToString(refreshTokenBytes)
|
||||
|
||||
@@ -84,66 +104,95 @@ func (rtm *RefreshTokenManager) GenerateTokenPair(loginID, device string) (*Refr
|
||||
ExpireTime: now.Add(rtm.refreshTTL).Unix(),
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
|
||||
key := rtm.getRefreshKey(refreshToken)
|
||||
if err := rtm.storage.Set(key, info, rtm.refreshTTL); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to store refresh token: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken generates new access token using refresh token | 使用刷新令牌生成新的访问令牌
|
||||
// RefreshAccessToken Generates new access token using refresh token | 使用刷新令牌生成新的访问令牌
|
||||
func (rtm *RefreshTokenManager) RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) {
|
||||
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
|
||||
if refreshToken == "" {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
key := rtm.getRefreshKey(refreshToken)
|
||||
|
||||
data, err := rtm.storage.Get(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid refresh token")
|
||||
if err != nil || data == nil {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
oldInfo, ok := data.(*RefreshTokenInfo)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid refresh token data")
|
||||
return nil, ErrInvalidRefreshData
|
||||
}
|
||||
|
||||
// Check expiration | 检查是否过期
|
||||
if time.Now().Unix() > oldInfo.ExpireTime {
|
||||
rtm.storage.Delete(key)
|
||||
return nil, fmt.Errorf("refresh token expired")
|
||||
return nil, ErrRefreshTokenExpired
|
||||
}
|
||||
|
||||
// Generate new access token | 生成新的访问令牌
|
||||
newAccessToken, err := rtm.tokenGen.Generate(oldInfo.LoginID, oldInfo.Device)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to generate new access token: %w", err)
|
||||
}
|
||||
|
||||
oldInfo.AccessToken = newAccessToken
|
||||
|
||||
// Update storage | 更新存储
|
||||
if err := rtm.storage.Set(key, oldInfo, rtm.refreshTTL); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to update refresh token: %w", err)
|
||||
}
|
||||
|
||||
return oldInfo, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken revokes a refresh token | 撤销刷新令牌
|
||||
// RevokeRefreshToken Revokes a refresh token | 撤销刷新令牌
|
||||
func (rtm *RefreshTokenManager) RevokeRefreshToken(refreshToken string) error {
|
||||
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
|
||||
if refreshToken == "" {
|
||||
return nil
|
||||
}
|
||||
key := rtm.getRefreshKey(refreshToken)
|
||||
return rtm.storage.Delete(key)
|
||||
}
|
||||
|
||||
// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息
|
||||
// GetRefreshTokenInfo Gets refresh token information | 获取刷新令牌信息
|
||||
func (rtm *RefreshTokenManager) GetRefreshTokenInfo(refreshToken string) (*RefreshTokenInfo, error) {
|
||||
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
|
||||
if refreshToken == "" {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
key := rtm.getRefreshKey(refreshToken)
|
||||
|
||||
data, err := rtm.storage.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err != nil || data == nil {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
info, ok := data.(*RefreshTokenInfo)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid refresh token data")
|
||||
return nil, ErrInvalidRefreshData
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// IsValid Checks if refresh token is valid | 检查刷新令牌是否有效
|
||||
func (rtm *RefreshTokenManager) IsValid(refreshToken string) bool {
|
||||
info, err := rtm.GetRefreshTokenInfo(refreshToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Now().Unix() <= info.ExpireTime
|
||||
}
|
||||
|
||||
// getRefreshKey Gets storage key for refresh token | 获取刷新令牌的存储键
|
||||
func (rtm *RefreshTokenManager) getRefreshKey(refreshToken string) string {
|
||||
return RefreshKeyPrefix + refreshToken
|
||||
}
|
||||
|
||||
+69
-26
@@ -9,29 +9,46 @@ import (
|
||||
"github.com/click33/sa-token-go/core/adapter"
|
||||
)
|
||||
|
||||
// Session session object for storing user data | 会话对象,用于存储用户数据
|
||||
// Constants for session keys | Session键常量
|
||||
const (
|
||||
SessionKeyPrefix = "session:" // Storage key prefix | 存储键前缀
|
||||
)
|
||||
|
||||
// Error variables | 错误变量
|
||||
var (
|
||||
ErrSessionNotFound = fmt.Errorf("session not found")
|
||||
ErrInvalidSessionData = fmt.Errorf("invalid session data")
|
||||
)
|
||||
|
||||
// Session Session object for storing user data | 会话对象,用于存储用户数据
|
||||
type Session struct {
|
||||
ID string `json:"id"` // Session ID | Session标识
|
||||
CreateTime int64 `json:"createTime"` // Creation time | 创建时间
|
||||
Data map[string]interface{} `json:"data"` // Session data | 数据
|
||||
mu sync.RWMutex `json:"-"` // Read-write lock | 读写锁
|
||||
storage adapter.Storage `json:"-"` // Storage backend | 存储
|
||||
prefix string `json:"-"` // Key prefix | 键前缀
|
||||
ID string `json:"id"` // Session ID | Session标识
|
||||
CreateTime int64 `json:"createTime"` // Creation time | 创建时间
|
||||
Data map[string]any `json:"data"` // Session data | 数据
|
||||
mu sync.RWMutex `json:"-"` // Read-write lock | 读写锁
|
||||
storage adapter.Storage `json:"-"` // Storage backend | 存储
|
||||
prefix string `json:"-"` // Key prefix | 键前缀
|
||||
}
|
||||
|
||||
// NewSession creates a new session | 创建新的Session
|
||||
// NewSession Creates a new session | 创建新的Session
|
||||
func NewSession(id string, storage adapter.Storage, prefix string) *Session {
|
||||
return &Session{
|
||||
ID: id,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Data: make(map[string]interface{}),
|
||||
Data: make(map[string]any),
|
||||
storage: storage,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets value | 设置值
|
||||
func (s *Session) Set(key string, value interface{}) error {
|
||||
// ============ Data Operations | 数据操作 ============
|
||||
|
||||
// Set Sets value | 设置值
|
||||
func (s *Session) Set(key string, value any) error {
|
||||
if key == "" {
|
||||
return fmt.Errorf("key cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -39,8 +56,8 @@ func (s *Session) Set(key string, value interface{}) error {
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// Get gets value | 获取值
|
||||
func (s *Session) Get(key string) (interface{}, bool) {
|
||||
// Get Gets value | 获取值
|
||||
func (s *Session) Get(key string) (any, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -116,16 +133,16 @@ func (s *Session) Delete(key string) error {
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// Clear 清空所有数据
|
||||
// Clear Clears all data | 清空所有数据
|
||||
func (s *Session) Clear() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.Data = make(map[string]interface{})
|
||||
s.Data = make(map[string]any)
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// Keys 获取所有键
|
||||
// Keys Gets all keys | 获取所有键
|
||||
func (s *Session) Keys() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -137,7 +154,7 @@ func (s *Session) Keys() []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
// Size 获取数据数量
|
||||
// Size Gets data count | 获取数据数量
|
||||
func (s *Session) Size() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -145,32 +162,55 @@ func (s *Session) Size() int {
|
||||
return len(s.Data)
|
||||
}
|
||||
|
||||
// save 保存到存储
|
||||
// IsEmpty Checks if session has no data | 检查Session是否为空
|
||||
func (s *Session) IsEmpty() bool {
|
||||
return s.Size() == 0
|
||||
}
|
||||
|
||||
// ============ Internal Methods | 内部方法 ============
|
||||
|
||||
// save Saves session to storage | 保存到存储
|
||||
func (s *Session) save() error {
|
||||
data, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal session: %w", err)
|
||||
}
|
||||
|
||||
key := s.prefix + "session:" + s.ID
|
||||
key := s.getStorageKey()
|
||||
return s.storage.Set(key, string(data), 0)
|
||||
}
|
||||
|
||||
// Load 从存储加载
|
||||
// getStorageKey Gets storage key for this session | 获取Session的存储键
|
||||
func (s *Session) getStorageKey() string {
|
||||
return s.prefix + SessionKeyPrefix + s.ID
|
||||
}
|
||||
|
||||
// ============ Static Methods | 静态方法 ============
|
||||
|
||||
// Load Loads session from storage | 从存储加载
|
||||
func Load(id string, storage adapter.Storage, prefix string) (*Session, error) {
|
||||
key := prefix + "session:" + id
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("session id cannot be empty")
|
||||
}
|
||||
|
||||
key := prefix + SessionKeyPrefix + id
|
||||
data, err := storage.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
dataStr, ok := data.(string)
|
||||
if !ok {
|
||||
return nil, ErrInvalidSessionData
|
||||
}
|
||||
|
||||
var session Session
|
||||
if err := json.Unmarshal([]byte(data.(string)), &session); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
|
||||
if err := json.Unmarshal([]byte(dataStr), &session); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidSessionData, err)
|
||||
}
|
||||
|
||||
session.storage = storage
|
||||
@@ -178,8 +218,11 @@ func Load(id string, storage adapter.Storage, prefix string) (*Session, error) {
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// Destroy 销毁Session
|
||||
// Destroy Destroys session | 销毁Session
|
||||
func (s *Session) Destroy() error {
|
||||
key := s.prefix + "session:" + s.ID
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := s.getStorageKey()
|
||||
return s.storage.Delete(key)
|
||||
}
|
||||
|
||||
+115
-47
@@ -3,36 +3,61 @@ package token
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/click33/sa-token-go/core/config"
|
||||
"github.com/click33/sa-token-go/core/utils"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Constants for token generation | Token生成常量
|
||||
const (
|
||||
DefaultJWTSecret = "default-secret-key" // Should be overridden in production | 生产环境应覆盖
|
||||
TikTokenLength = 11 // TikTok-style short ID length | Tik风格短ID长度
|
||||
TikCharset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
HashRandomBytesLen = 16 // Random bytes length for hash token | 哈希Token的随机字节长度
|
||||
TimestampRandomLen = 8 // Random bytes length for timestamp token | 时间戳Token的随机字节长度
|
||||
DefaultSimpleLength = 16 // Default simple token length | 默认简单Token长度
|
||||
)
|
||||
|
||||
// Error variables | 错误变量
|
||||
var (
|
||||
ErrInvalidToken = fmt.Errorf("invalid token")
|
||||
ErrUnexpectedSigningMethod = fmt.Errorf("unexpected signing method")
|
||||
)
|
||||
|
||||
// Generator Token generator | Token生成器
|
||||
type Generator struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewGenerator creates a new token generator | 创建新的Token生成器
|
||||
// NewGenerator Creates a new token generator | 创建新的Token生成器
|
||||
func NewGenerator(cfg *config.Config) *Generator {
|
||||
if cfg == nil {
|
||||
cfg = config.DefaultConfig()
|
||||
}
|
||||
return &Generator{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Generate generates token based on configured style | 根据配置的风格生成Token
|
||||
// ============ Public Methods | 公共方法 ============
|
||||
|
||||
// Generate Generates token based on configured style | 根据配置的风格生成Token
|
||||
func (g *Generator) Generate(loginID string, device string) (string, error) {
|
||||
if loginID == "" {
|
||||
return "", fmt.Errorf("loginID cannot be empty")
|
||||
}
|
||||
|
||||
switch g.config.TokenStyle {
|
||||
case config.TokenStyleUUID:
|
||||
return g.generateUUID()
|
||||
case config.TokenStyleSimple:
|
||||
return g.generateSimple(16)
|
||||
return g.generateSimple(DefaultSimpleLength)
|
||||
case config.TokenStyleRandom32:
|
||||
return g.generateSimple(32)
|
||||
case config.TokenStyleRandom64:
|
||||
@@ -52,21 +77,31 @@ func (g *Generator) Generate(loginID string, device string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// generateUUID generates UUID token | 生成UUID Token
|
||||
// ============ Token Generation Methods | Token生成方法 ============
|
||||
|
||||
// generateUUID Generates UUID token | 生成UUID Token
|
||||
func (g *Generator) generateUUID() (string, error) {
|
||||
return uuid.New().String(), nil
|
||||
}
|
||||
|
||||
// generateSimple generates simple random string token | 生成简单随机字符串Token
|
||||
func (g *Generator) generateSimple(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
u, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate UUID: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// generateJWT generates JWT token | 生成JWT Token
|
||||
// generateSimple Generates simple random string token | 生成简单随机字符串Token
|
||||
func (g *Generator) generateSimple(length int) (string, error) {
|
||||
if length <= 0 {
|
||||
length = DefaultSimpleLength
|
||||
}
|
||||
|
||||
token := utils.RandomString(length)
|
||||
if token == "" {
|
||||
return "", fmt.Errorf("failed to generate random string")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// generateJWT Generates JWT token | 生成JWT Token
|
||||
func (g *Generator) generateJWT(loginID string, device string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
@@ -75,71 +110,105 @@ func (g *Generator) generateJWT(loginID string, device string) (string, error) {
|
||||
"iat": now.Unix(),
|
||||
}
|
||||
|
||||
// Add expiration if timeout is configured | 如果配置了超时时间则添加过期时间
|
||||
if g.config.Timeout > 0 {
|
||||
claims["exp"] = now.Add(time.Duration(g.config.Timeout) * time.Second).Unix()
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
secretKey := g.config.JwtSecretKey
|
||||
if secretKey == "" {
|
||||
secretKey = "default-secret-key"
|
||||
secretKey := g.getJWTSecret()
|
||||
|
||||
signedToken, err := token.SignedString([]byte(secretKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign JWT token: %w", err)
|
||||
}
|
||||
|
||||
return token.SignedString([]byte(secretKey))
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
// ParseJWT parses JWT token and returns claims | 解析JWT Token并返回声明
|
||||
// getJWTSecret Gets JWT secret key with fallback | 获取JWT密钥(带默认值)
|
||||
func (g *Generator) getJWTSecret() string {
|
||||
if g.config.JwtSecretKey != "" {
|
||||
return g.config.JwtSecretKey
|
||||
}
|
||||
return DefaultJWTSecret
|
||||
}
|
||||
|
||||
// ============ JWT Helper Methods | JWT辅助方法 ============
|
||||
|
||||
// ParseJWT Parses JWT token and returns claims | 解析JWT Token并返回声明
|
||||
func (g *Generator) ParseJWT(tokenStr string) (jwt.MapClaims, error) {
|
||||
secretKey := g.config.JwtSecretKey
|
||||
if secretKey == "" {
|
||||
secretKey = "default-secret-key"
|
||||
if tokenStr == "" {
|
||||
return nil, fmt.Errorf("token string cannot be empty")
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
secretKey := g.getJWTSecret()
|
||||
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (any, error) {
|
||||
// Verify signing method | 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
return nil, fmt.Errorf("%w: %v", ErrUnexpectedSigningMethod, token.Header["alg"])
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// ValidateJWT validates JWT token | 验证JWT Token
|
||||
// ValidateJWT Validates JWT token | 验证JWT Token
|
||||
func (g *Generator) ValidateJWT(tokenStr string) error {
|
||||
_, err := g.ParseJWT(tokenStr)
|
||||
return err
|
||||
}
|
||||
|
||||
// generateHash generates SHA256 hash-based token | 生成SHA256哈希风格Token
|
||||
func (g *Generator) generateHash(loginID string, device string) (string, error) {
|
||||
// Combine loginID, device, timestamp and random bytes
|
||||
// 组合 loginID、device、时间戳和随机字节
|
||||
randomBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
// GetLoginIDFromJWT Extracts login ID from JWT token | 从JWT Token中提取登录ID
|
||||
func (g *Generator) GetLoginIDFromJWT(tokenStr string) (string, error) {
|
||||
claims, err := g.ParseJWT(tokenStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data := fmt.Sprintf("%s:%s:%d:%s", loginID, device, time.Now().UnixNano(), hex.EncodeToString(randomBytes))
|
||||
loginID, ok := claims["loginId"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("loginId not found in token claims")
|
||||
}
|
||||
|
||||
return loginID, nil
|
||||
}
|
||||
|
||||
// generateHash Generates SHA256 hash-based token | 生成SHA256哈希风格Token
|
||||
func (g *Generator) generateHash(loginID string, device string) (string, error) {
|
||||
// Combine loginID, device, timestamp and random bytes | 组合 loginID、device、时间戳和随机字节
|
||||
randomBytes := make([]byte, HashRandomBytesLen)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Create hash input | 创建哈希输入
|
||||
data := fmt.Sprintf("%s:%s:%d:%s",
|
||||
loginID,
|
||||
device,
|
||||
time.Now().UnixNano(),
|
||||
hex.EncodeToString(randomBytes))
|
||||
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:]), nil
|
||||
}
|
||||
|
||||
// generateTimestamp generates timestamp-based token | 生成时间戳风格Token
|
||||
// generateTimestamp Generates timestamp-based token | 生成时间戳风格Token
|
||||
func (g *Generator) generateTimestamp(loginID string, device string) (string, error) {
|
||||
// Format: timestamp_loginID_random
|
||||
// 格式:时间戳_loginID_随机数
|
||||
randomBytes := make([]byte, 8)
|
||||
// Format: timestamp_loginID_random | 格式:时间戳_loginID_随机数
|
||||
randomBytes := make([]byte, TimestampRandomLen)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
timestamp := time.Now().UnixMilli()
|
||||
@@ -147,18 +216,17 @@ func (g *Generator) generateTimestamp(loginID string, device string) (string, er
|
||||
return fmt.Sprintf("%d_%s_%s", timestamp, loginID, random), nil
|
||||
}
|
||||
|
||||
// generateTik generates short ID style token (like TikTok) | 生成Tik风格短ID Token(类似抖音)
|
||||
// generateTik Generates short ID style token (like TikTok) | 生成Tik风格短ID Token(类似抖音)
|
||||
func (g *Generator) generateTik() (string, error) {
|
||||
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
const length = 11 // TikTok-style short ID length | 抖音风格短ID长度
|
||||
result := make([]byte, TikTokenLength)
|
||||
charsetLen := int64(len(TikCharset))
|
||||
|
||||
result := make([]byte, length)
|
||||
for i := range result {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(charsetLen))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to generate random number: %w", err)
|
||||
}
|
||||
result[i] = charset[num.Int64()]
|
||||
result[i] = TikCharset[num.Int64()]
|
||||
}
|
||||
|
||||
return string(result), nil
|
||||
|
||||
+408
-37
@@ -2,19 +2,100 @@ package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Constants for time durations | 时间常量
|
||||
const (
|
||||
Second = 1
|
||||
Minute = 60 * Second
|
||||
Hour = 60 * Minute
|
||||
Day = 24 * Hour
|
||||
Week = 7 * Day
|
||||
)
|
||||
|
||||
// Constants for string operations | 字符串操作常量
|
||||
const (
|
||||
DefaultSeparator = ","
|
||||
WildcardChar = "*"
|
||||
)
|
||||
|
||||
// ============ Random Generation | 随机生成 ============
|
||||
|
||||
// RandomString generates random string of specified length | 生成指定长度的随机字符串
|
||||
func RandomString(length int) string {
|
||||
bytes := make([]byte, length)
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Calculate required byte length (base64 expands by ~33%)
|
||||
byteLen := (length * 3) / 4
|
||||
if byteLen < length {
|
||||
byteLen = length
|
||||
}
|
||||
|
||||
bytes := make([]byte, byteLen)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return ""
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)[:length]
|
||||
|
||||
encoded := base64.URLEncoding.EncodeToString(bytes)
|
||||
// Remove padding and trim to exact length
|
||||
encoded = strings.TrimRight(encoded, "=")
|
||||
if len(encoded) > length {
|
||||
return encoded[:length]
|
||||
}
|
||||
return encoded
|
||||
}
|
||||
|
||||
// RandomNumericString generates random numeric string | 生成随机数字字符串
|
||||
func RandomNumericString(length int) string {
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
const digits = "0123456789"
|
||||
result := make([]byte, length)
|
||||
max := big.NewInt(int64(len(digits)))
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
n, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
result[i] = digits[n.Int64()]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// RandomAlphanumeric generates random alphanumeric string | 生成随机字母数字字符串
|
||||
func RandomAlphanumeric(length int) string {
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
result := make([]byte, length)
|
||||
max := big.NewInt(int64(len(chars)))
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
n, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
result[i] = chars[n.Int64()]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// IsEmpty checks if string is empty | 检查字符串是否为空
|
||||
@@ -47,7 +128,7 @@ func ContainsString(slice []string, item string) bool {
|
||||
|
||||
// RemoveString removes item from string slice | 从字符串数组中移除指定字符串
|
||||
func RemoveString(slice []string, item string) []string {
|
||||
result := make([]string, 0)
|
||||
result := make([]string, 0, len(slice))
|
||||
for _, s := range slice {
|
||||
if s != item {
|
||||
result = append(result, s)
|
||||
@@ -58,8 +139,12 @@ func RemoveString(slice []string, item string) []string {
|
||||
|
||||
// UniqueStrings removes duplicates from string slice | 字符串数组去重
|
||||
func UniqueStrings(slice []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
result := make([]string, 0)
|
||||
if len(slice) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(slice))
|
||||
result := make([]string, 0, len(slice))
|
||||
for _, s := range slice {
|
||||
if !seen[s] {
|
||||
seen[s] = true
|
||||
@@ -69,19 +154,53 @@ func UniqueStrings(slice []string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// MergeStrings 合并多个字符串数组并去重
|
||||
// FilterStrings filters string slice by predicate | 根据条件过滤字符串数组
|
||||
func FilterStrings(slice []string, predicate func(string) bool) []string {
|
||||
result := make([]string, 0, len(slice))
|
||||
for _, s := range slice {
|
||||
if predicate(s) {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MapStrings applies function to each string in slice | 对数组中每个字符串应用函数
|
||||
func MapStrings(slice []string, mapper func(string) string) []string {
|
||||
result := make([]string, len(slice))
|
||||
for i, s := range slice {
|
||||
result[i] = mapper(s)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MergeStrings Merges multiple string slices and removes duplicates | 合并多个字符串数组并去重
|
||||
func MergeStrings(slices ...[]string) []string {
|
||||
result := make([]string, 0)
|
||||
if len(slices) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Pre-calculate total capacity
|
||||
totalLen := 0
|
||||
for _, slice := range slices {
|
||||
totalLen += len(slice)
|
||||
}
|
||||
|
||||
result := make([]string, 0, totalLen)
|
||||
for _, slice := range slices {
|
||||
result = append(result, slice...)
|
||||
}
|
||||
return UniqueStrings(result)
|
||||
}
|
||||
|
||||
// SplitAndTrim 分割字符串并去除空格
|
||||
// SplitAndTrim Splits string and trims whitespace | 分割字符串并去除空格
|
||||
func SplitAndTrim(s, sep string) []string {
|
||||
if s == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
parts := strings.Split(s, sep)
|
||||
result := make([]string, 0)
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
@@ -91,6 +210,17 @@ func SplitAndTrim(s, sep string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// JoinNonEmpty Joins non-empty strings | 连接非空字符串
|
||||
func JoinNonEmpty(sep string, strs ...string) string {
|
||||
nonEmpty := make([]string, 0, len(strs))
|
||||
for _, s := range strs {
|
||||
if IsNotEmpty(s) {
|
||||
nonEmpty = append(nonEmpty, s)
|
||||
}
|
||||
}
|
||||
return strings.Join(nonEmpty, sep)
|
||||
}
|
||||
|
||||
// GetStructTag 获取结构体字段的标签值
|
||||
func GetStructTag(field reflect.StructField, tag string) string {
|
||||
return field.Tag.Get(tag)
|
||||
@@ -120,18 +250,18 @@ func ParseRoleTag(tag string) []string {
|
||||
return SplitAndTrim(tag, ",")
|
||||
}
|
||||
|
||||
// MatchPattern 模式匹配(支持通配符*)
|
||||
// MatchPattern Pattern matching with wildcard support | 模式匹配(支持通配符*)
|
||||
func MatchPattern(pattern, str string) bool {
|
||||
if pattern == "*" {
|
||||
if pattern == WildcardChar {
|
||||
return true
|
||||
}
|
||||
|
||||
if !strings.Contains(pattern, "*") {
|
||||
if !strings.Contains(pattern, WildcardChar) {
|
||||
return pattern == str
|
||||
}
|
||||
|
||||
// 简单的通配符匹配
|
||||
parts := strings.Split(pattern, "*")
|
||||
// Simple wildcard matching | 简单的通配符匹配
|
||||
parts := strings.Split(pattern, WildcardChar)
|
||||
if len(parts) == 2 {
|
||||
prefix, suffix := parts[0], parts[1]
|
||||
if prefix != "" && !strings.HasPrefix(str, prefix) {
|
||||
@@ -143,61 +273,302 @@ func MatchPattern(pattern, str string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
// Complex pattern with multiple wildcards | 复杂模式(多个通配符)
|
||||
pos := 0
|
||||
for i, part := range parts {
|
||||
if i == 0 && part != "" {
|
||||
if !strings.HasPrefix(str, part) {
|
||||
return false
|
||||
}
|
||||
pos += len(part)
|
||||
continue
|
||||
}
|
||||
if i == len(parts)-1 && part != "" {
|
||||
return strings.HasSuffix(str, part)
|
||||
}
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
idx := strings.Index(str[pos:], part)
|
||||
if idx == -1 {
|
||||
return false
|
||||
}
|
||||
pos += idx + len(part)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// FormatDuration 格式化时间段(秒)为人类可读格式
|
||||
// ============ Time & Duration | 时间和时长 ============
|
||||
|
||||
// FormatDuration Formats duration in seconds to human-readable format | 格式化时间段(秒)为人类可读格式
|
||||
func FormatDuration(seconds int64) string {
|
||||
if seconds < 0 {
|
||||
return "永久"
|
||||
}
|
||||
|
||||
if seconds < 60 {
|
||||
if seconds == 0 {
|
||||
return "0秒"
|
||||
}
|
||||
|
||||
if seconds < Minute {
|
||||
return fmt.Sprintf("%d秒", seconds)
|
||||
}
|
||||
|
||||
if seconds < 3600 {
|
||||
minutes := seconds / 60
|
||||
if seconds < Hour {
|
||||
minutes := seconds / Minute
|
||||
return fmt.Sprintf("%d分钟", minutes)
|
||||
}
|
||||
|
||||
if seconds < 86400 {
|
||||
hours := seconds / 3600
|
||||
if seconds < Day {
|
||||
hours := seconds / Hour
|
||||
return fmt.Sprintf("%d小时", hours)
|
||||
}
|
||||
|
||||
days := seconds / 86400
|
||||
return fmt.Sprintf("%d天", days)
|
||||
if seconds < Week {
|
||||
days := seconds / Day
|
||||
return fmt.Sprintf("%d天", days)
|
||||
}
|
||||
|
||||
weeks := seconds / Week
|
||||
return fmt.Sprintf("%d周", weeks)
|
||||
}
|
||||
|
||||
// ParseDuration 解析人类可读的时间段为秒
|
||||
// ParseDuration Parses human-readable duration to seconds | 解析人类可读的时间段为秒
|
||||
func ParseDuration(duration string) int64 {
|
||||
duration = strings.ToLower(strings.TrimSpace(duration))
|
||||
|
||||
if strings.HasSuffix(duration, "s") || strings.HasSuffix(duration, "秒") {
|
||||
return parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "s"), "秒"))
|
||||
if duration == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
if strings.HasSuffix(duration, "m") || strings.HasSuffix(duration, "分") {
|
||||
minutes := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "m"), "分"))
|
||||
return minutes * 60
|
||||
}
|
||||
|
||||
if strings.HasSuffix(duration, "h") || strings.HasSuffix(duration, "时") {
|
||||
hours := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "h"), "时"))
|
||||
return hours * 3600
|
||||
// Week | 周
|
||||
if strings.HasSuffix(duration, "w") || strings.HasSuffix(duration, "周") {
|
||||
weeks := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "w"), "周"))
|
||||
return weeks * Week
|
||||
}
|
||||
|
||||
// Day | 天
|
||||
if strings.HasSuffix(duration, "d") || strings.HasSuffix(duration, "天") {
|
||||
days := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "d"), "天"))
|
||||
return days * 86400
|
||||
return days * Day
|
||||
}
|
||||
|
||||
// Hour | 小时
|
||||
if strings.HasSuffix(duration, "h") || strings.HasSuffix(duration, "时") || strings.HasSuffix(duration, "小时") {
|
||||
hours := parseInt64(strings.TrimSuffix(strings.TrimSuffix(strings.TrimSuffix(duration, "h"), "时"), "小时"))
|
||||
return hours * Hour
|
||||
}
|
||||
|
||||
// Minute | 分钟
|
||||
if strings.HasSuffix(duration, "m") || strings.HasSuffix(duration, "分") || strings.HasSuffix(duration, "分钟") {
|
||||
minutes := parseInt64(strings.TrimSuffix(strings.TrimSuffix(strings.TrimSuffix(duration, "m"), "分"), "分钟"))
|
||||
return minutes * Minute
|
||||
}
|
||||
|
||||
// Second | 秒
|
||||
if strings.HasSuffix(duration, "s") || strings.HasSuffix(duration, "秒") {
|
||||
return parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "s"), "秒"))
|
||||
}
|
||||
|
||||
return parseInt64(duration)
|
||||
}
|
||||
|
||||
// TimestampToTime Converts Unix timestamp to time.Time | Unix时间戳转time.Time
|
||||
func TimestampToTime(timestamp int64) time.Time {
|
||||
return time.Unix(timestamp, 0)
|
||||
}
|
||||
|
||||
// TimeToTimestamp Converts time.Time to Unix timestamp | time.Time转Unix时间戳
|
||||
func TimeToTimestamp(t time.Time) int64 {
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
// ============ Type Conversion | 类型转换 ============
|
||||
|
||||
// parseInt64 Parses string to int64 | 将字符串解析为int64
|
||||
func parseInt64(s string) int64 {
|
||||
var result int64
|
||||
fmt.Sscanf(s, "%d", &result)
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
result, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToInt Converts any to int | 将any转换为int
|
||||
func ToInt(v any) (int, error) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, nil
|
||||
case int32:
|
||||
return int(val), nil
|
||||
case int64:
|
||||
return int(val), nil
|
||||
case float32:
|
||||
return int(val), nil
|
||||
case float64:
|
||||
return int(val), nil
|
||||
case string:
|
||||
return strconv.Atoi(val)
|
||||
default:
|
||||
return 0, fmt.Errorf("cannot convert %T to int", v)
|
||||
}
|
||||
}
|
||||
|
||||
// ToInt64 Converts any to int64 | 将any转换为int64
|
||||
func ToInt64(v any) (int64, error) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return int64(val), nil
|
||||
case int32:
|
||||
return int64(val), nil
|
||||
case int64:
|
||||
return val, nil
|
||||
case float32:
|
||||
return int64(val), nil
|
||||
case float64:
|
||||
return int64(val), nil
|
||||
case string:
|
||||
return strconv.ParseInt(val, 10, 64)
|
||||
default:
|
||||
return 0, fmt.Errorf("cannot convert %T to int64", v)
|
||||
}
|
||||
}
|
||||
|
||||
// ToString Converts any to string | 将any转换为string
|
||||
func ToString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val
|
||||
case []byte:
|
||||
return string(val)
|
||||
case int, int8, int16, int32, int64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%v", val)
|
||||
case bool:
|
||||
return strconv.FormatBool(val)
|
||||
default:
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
}
|
||||
|
||||
// ToBool Converts any to bool | 将any转换为bool
|
||||
func ToBool(v any) (bool, error) {
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return val, nil
|
||||
case string:
|
||||
return strconv.ParseBool(val)
|
||||
case int, int8, int16, int32, int64:
|
||||
return val != 0, nil
|
||||
default:
|
||||
return false, fmt.Errorf("cannot convert %T to bool", v)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Hash & Encoding | 哈希和编码 ============
|
||||
|
||||
// SHA256Hash Generates SHA256 hash of string | 生成字符串的SHA256哈希
|
||||
func SHA256Hash(s string) string {
|
||||
hash := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Base64Encode Encodes string to base64 | Base64编码
|
||||
func Base64Encode(s string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
|
||||
// Base64Decode Decodes base64 string | Base64解码
|
||||
func Base64Decode(s string) (string, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Base64URLEncode Encodes string to URL-safe base64 | URL安全的Base64编码
|
||||
func Base64URLEncode(s string) string {
|
||||
return base64.URLEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
|
||||
// Base64URLDecode Decodes URL-safe base64 string | URL安全的Base64解码
|
||||
func Base64URLDecode(s string) (string, error) {
|
||||
data, err := base64.URLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// ============ Validation | 验证 ============
|
||||
|
||||
// IsAlphanumeric Checks if string contains only alphanumeric characters | 检查是否只包含字母数字
|
||||
func IsAlphanumeric(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsNumeric Checks if string contains only numbers | 检查是否只包含数字
|
||||
func IsNumeric(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if r < '0' || r > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// HasLength Checks if string length is within range | 检查字符串长度是否在范围内
|
||||
func HasLength(s string, min, max int) bool {
|
||||
length := len(s)
|
||||
return length >= min && length <= max
|
||||
}
|
||||
|
||||
// ============ Slice Helpers | 切片辅助 ============
|
||||
|
||||
// InSlice Checks if value exists in slice | 检查值是否存在于切片中
|
||||
func InSlice[T comparable](slice []T, val T) bool {
|
||||
for _, item := range slice {
|
||||
if item == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UniqueSlice Removes duplicates from slice | 去除切片中的重复元素
|
||||
func UniqueSlice[T comparable](slice []T) []T {
|
||||
seen := make(map[T]bool, len(slice))
|
||||
result := make([]T, 0, len(slice))
|
||||
for _, item := range slice {
|
||||
if !seen[item] {
|
||||
seen[item] = true
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/click33/sa-token-go/core/banner"
|
||||
"github.com/click33/sa-token-go/core/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 1. 打印基础 Banner
|
||||
banner.Print()
|
||||
|
||||
// 2. 打印带完整配置的 Banner
|
||||
cfg := config.DefaultConfig()
|
||||
banner.PrintWithConfig(cfg)
|
||||
|
||||
// 3. 打印 JWT 配置的 Banner
|
||||
jwtCfg := &config.Config{
|
||||
TokenName: "jwt-token",
|
||||
Timeout: 86400, // 24小时
|
||||
ActiveTimeout: -1,
|
||||
IsConcurrent: true,
|
||||
IsShare: false,
|
||||
MaxLoginCount: 5,
|
||||
IsReadBody: false,
|
||||
IsReadHeader: true,
|
||||
IsReadCookie: true,
|
||||
TokenStyle: config.TokenStyleJWT,
|
||||
DataRefreshPeriod: -1,
|
||||
TokenSessionCheckLogin: true,
|
||||
AutoRenew: true,
|
||||
JwtSecretKey: "my-super-secret-key-123456",
|
||||
IsLog: true,
|
||||
IsPrintBanner: true,
|
||||
CookieConfig: &config.CookieConfig{
|
||||
Domain: "example.com",
|
||||
Path: "/api",
|
||||
Secure: true,
|
||||
HttpOnly: true,
|
||||
SameSite: config.SameSiteStrict,
|
||||
MaxAge: 7200,
|
||||
},
|
||||
}
|
||||
|
||||
banner.PrintWithConfig(jwtCfg)
|
||||
}
|
||||
@@ -19,7 +19,7 @@ var (
|
||||
|
||||
// item 存储项
|
||||
type item struct {
|
||||
value interface{}
|
||||
value any
|
||||
expiration int64 // 过期时间戳(0表示永不过期)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func NewStorageWithCleanupInterval(interval time.Duration) adapter.Storage {
|
||||
}
|
||||
|
||||
// Set 设置键值对
|
||||
func (s *Storage) Set(key string, value interface{}, expiration time.Duration) error {
|
||||
func (s *Storage) Set(key string, value any, expiration time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -72,7 +72,7 @@ func (s *Storage) Set(key string, value interface{}, expiration time.Duration) e
|
||||
}
|
||||
|
||||
// Get 获取值
|
||||
func (s *Storage) Get(key string) (interface{}, error) {
|
||||
func (s *Storage) Get(key string) (any, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
s.mu.RLock()
|
||||
@@ -93,11 +93,13 @@ func (s *Storage) Get(key string) (interface{}, error) {
|
||||
}
|
||||
|
||||
// Delete 删除键
|
||||
func (s *Storage) Delete(key string) error {
|
||||
func (s *Storage) Delete(keys ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.data, key)
|
||||
for _, key := range keys {
|
||||
delete(s.data, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -194,6 +196,17 @@ func (s *Storage) Clear() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping 检查存储可用性
|
||||
func (s *Storage) Ping() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.closed {
|
||||
return errors.New("storage is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭存储,停止清理协程
|
||||
func (s *Storage) Close() error {
|
||||
s.mu.Lock()
|
||||
|
||||
+20
-4
@@ -104,14 +104,14 @@ func (s *Storage) getKey(key string) string {
|
||||
}
|
||||
|
||||
// Set 设置键值对
|
||||
func (s *Storage) Set(key string, value interface{}, expiration time.Duration) error {
|
||||
func (s *Storage) Set(key string, value any, expiration time.Duration) error {
|
||||
ctx, cancel := s.withTimeout()
|
||||
defer cancel()
|
||||
return s.client.Set(ctx, s.getKey(key), value, expiration).Err()
|
||||
}
|
||||
|
||||
// Get 获取值
|
||||
func (s *Storage) Get(key string) (interface{}, error) {
|
||||
func (s *Storage) Get(key string) (any, error) {
|
||||
ctx, cancel := s.withTimeout()
|
||||
defer cancel()
|
||||
val, err := s.client.Get(ctx, s.getKey(key)).Result()
|
||||
@@ -125,10 +125,19 @@ func (s *Storage) Get(key string) (interface{}, error) {
|
||||
}
|
||||
|
||||
// Delete 删除键
|
||||
func (s *Storage) Delete(key string) error {
|
||||
func (s *Storage) Delete(keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := s.withTimeout()
|
||||
defer cancel()
|
||||
return s.client.Del(ctx, s.getKey(key)).Err()
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = s.getKey(key)
|
||||
}
|
||||
return s.client.Del(ctx, fullKeys...).Err()
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
@@ -215,6 +224,13 @@ func (s *Storage) Clear() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping 检查连接
|
||||
func (s *Storage) Ping() error {
|
||||
ctx, cancel := s.withTimeout()
|
||||
defer cancel()
|
||||
return s.client.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (s *Storage) Close() error {
|
||||
return s.client.Close()
|
||||
|
||||
Reference in New Issue
Block a user