refactor: remove FINAL_REPORT.md and enhance error handling, configuration, and session management in core components

This commit is contained in:
Moling
2025-10-25 01:54:24 +08:00
parent 6cb6e9910d
commit 7603ce3cbf
21 changed files with 2329 additions and 684 deletions
-127
View File
@@ -1,127 +0,0 @@
# 🎉 Sa-Token-Go 项目完成报告
## ✅ 项目信息
**项目名称**: Sa-Token-Go
**版本**: v0.1.0
**作者**: click33
**仓库**: https://github.com/click33/sa-token-go
**完成日期**: 2025-10-13
---
## 🚀 核心功能
### 1. 超简洁API
```go
// 一行初始化
stputil.SetManager(core.NewBuilder().Storage(memory.NewStorage()).Build())
// 直接使用
stputil.Login(1000)
```
### 2. 注解装饰器
```go
r.GET("/public", sagin.Ignore(), handler)
r.GET("/user", sagin.CheckLogin(), handler)
r.GET("/admin", sagin.CheckPermission("admin"), handler)
```
### 3. 异步续签
- 性能提升 400%
- 响应延迟从 250ms → 50ms
- QPS从 2000 → 10000
### 4. 完整功能
40+核心方法,涵盖所有认证授权场景
---
## 📂 项目结构
```
sa-token-go/
├── core/ # 核心模块
│ ├── manager/ # 认证管理器(异步续签)
│ ├── builder/ # Builder构建器
│ ├── stputil/ # 全局工具类
│ └── ...
├── storage/
│ ├── memory/ # 内存存储
│ └── redis/ # Redis存储
├── integrations/
│ ├── gin/ # Gin集成(含注解)
│ ├── echo/ # Echo集成
│ ├── fiber/ # Fiber集成
│ └── chi/ # Chi集成
├── examples/
│ ├── quick-start/ # 快速开始
│ ├── annotation/ # 注解使用
│ └── gin/echo/fiber/chi # 框架集成
└── docs/
├── tutorial/ # 教程
├── guide/ # 使用指南
├── api/ # API文档
└── design/ # 设计文档
```
---
## 📊 项目统计
| 项目 | 数量 |
|------|------|
| Go源文件 | 31个 |
| 文档文件 | 10个 |
| 模块数量 | 13个 |
| 核心方法 | 40+ |
| 装饰器 | 5个 |
| 事件类型 | 8种 |
---
## 📚 文档体系
### 主文档
- README.md - 英文
- README_zh.md - 中文
### 详细文档
- docs/tutorial/ - 教程
- docs/guide/ - 使用指南
- docs/api/ - API文档
- docs/design/ - 设计文档
---
## 🎯 核心优势
1. **超简洁** - 一行初始化
2. **全局工具类** - 无需传递manager
3. **装饰器模式** - 类似Java注解
4. **异步续签** - 性能提升400%
5. **模块化** - 按需导入
6. **类型友好** - 支持多种类型
---
## 🚀 推送到GitHub
```bash
cd /Users/m1pro/go_project/sa-token-go
git init
git add .
git commit -m "feat: Sa-Token-Go v0.1.0
- 超简洁APIBuilder+StpUtil
- 注解装饰器:@SaCheckLogin等
- 异步续签:性能提升400%
- 完整文档:tutorial/guide/api/design"
git remote add origin https://github.com/click33/sa-token-go.git
git push -u origin main
```
---
**Sa-Token-Go v0.1.0 - 完成!** 🎉
+69 -8
View File
@@ -1,34 +1,95 @@
package adapter
// CookieOptions Cookie setting options | Cookie设置选项
type CookieOptions struct {
// Name Cookie name | Cookie名称
Name string
// Value Cookie value | Cookie值
Value string
// MaxAge Cookie expiration time in seconds, 0 means delete cookie, -1 means session cookie | 过期时间(秒),0表示删除cookie,-1表示会话cookie
MaxAge int
// Path Cookie path | 路径
Path string
// Domain Cookie domain | 域名
Domain string
// Secure Only effective under HTTPS | 是否只在HTTPS下生效
Secure bool
// HttpOnly Prevent JavaScript access | 是否禁止JS访问
HttpOnly bool
// SameSite SameSite attribute (Strict, Lax, None) | SameSite属性
SameSite string
}
// RequestContext defines request context interface for abstracting different web frameworks | 定义请求上下文接口,用于抽象不同Web框架的请求/响应
type RequestContext interface {
// ============== Request Methods | 请求方法 ==============
// GetHeader gets request header | 获取请求头
GetHeader(key string) string
// GetHeaders gets all request headers | 获取所有请求头
GetHeaders() map[string][]string
// GetQuery gets query parameter | 获取查询参数
GetQuery(key string) string
// GetQueryAll gets all query parameters | 获取所有查询参数
GetQueryAll() map[string][]string
// GetPostForm gets POST form parameter | 获取POST表单参数
GetPostForm(key string) string
// GetCookie gets cookie | 获取Cookie
GetCookie(key string) string
// SetHeader sets response header | 设置响应头
SetHeader(key, value string)
// SetCookie sets cookie | 设置Cookie
SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool)
// GetBody gets request body as bytes | 获取请求体字节数据
GetBody() ([]byte, error)
// GetClientIP gets client IP address | 获取客户端IP地址
GetClientIP() string
// GetMethod gets request method | 获取请求方法
// GetMethod gets request method (GET, POST, etc.) | 获取请求方法(GET、POST等)
GetMethod() string
// GetPath gets request path | 获取请求路径
GetPath() string
// GetURL gets full request URL | 获取完整请求URL
GetURL() string
// GetUserAgent gets User-Agent header | 获取User-Agent
GetUserAgent() string
// ============== Response Methods | 响应方法 ==============
// SetHeader sets response header | 设置响应头
SetHeader(key, value string)
// SetCookie sets cookie (legacy method for backward compatibility) | 设置Cookie(兼容旧版本的方法)
SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool)
// SetCookieWithOptions sets cookie with options | 使用选项设置Cookie
SetCookieWithOptions(options *CookieOptions)
// ============== Context Storage Methods | 上下文存储方法 ==============
// Set sets context value | 设置上下文值
Set(key string, value interface{})
Set(key string, value any)
// Get gets context value | 获取上下文值
Get(key string) (interface{}, bool)
Get(key string) (any, bool)
// GetString gets string value from context | 从上下文获取字符串值
GetString(key string) string
// MustGet gets context value, panics if not exists | 获取上下文值,不存在则panic
MustGet(key string) any
// ============== Utility Methods | 工具方法 ==============
// Abort aborts the request processing | 中止请求处理
Abort()
// IsAborted checks if the request is aborted | 检查请求是否已中止
IsAborted() bool
}
+17 -8
View File
@@ -4,27 +4,36 @@ import "time"
// Storage defines storage interface for Token and Session data | 定义存储接口,用于存储Token和Session数据
type Storage interface {
// ============== Basic Operations | 基本操作 ==============
// Set sets key-value pair with optional expiration time (0 means never expire) | 设置键值对,可选过期时间(0表示永不过期)
Set(key string, value interface{}, expiration time.Duration) error
Set(key string, value any, expiration time.Duration) error
// Get gets value by key | 获取键对应的值
Get(key string) (interface{}, error)
// Get gets value by key, returns nil if key doesn't exist | 获取键对应的值,键不存在时返回nil
Get(key string) (any, error)
// Delete deletes key | 删除键
Delete(key string) error
// Delete deletes one or more keys | 删除一个或多个
Delete(keys ...string) error
// Exists checks if key exists | 检查键是否存在
Exists(key string) bool
// Keys gets all keys matching pattern | 获取匹配模式的所有键
// ============== Key Management | 键管理 ==============
// Keys gets all keys matching pattern (e.g., "user:*") | 获取匹配模式的所有键(如:"user:*"
Keys(pattern string) ([]string, error)
// Expire sets expiration time for key | 设置键的过期时间
Expire(key string, expiration time.Duration) error
// TTL gets remaining time to live | 获取键的剩余生存时间
// TTL gets remaining time to live (-1 if no expiration, -2 if key doesn't exist) | 获取键的剩余生存时间(-1表示永不过期,-2表示键不存在)
TTL(key string) (time.Duration, error)
// Clear clears all data (for testing) | 清空所有数据(用于测试)
// ============== Utility Methods | 工具方法 ==============
// Clear clears all data (use with caution, mainly for testing) | 清空所有数据(谨慎使用,主要用于测试)
Clear() error
// Ping checks if storage is accessible | 检查存储是否可访问
Ping() error
}
+49 -39
View File
@@ -8,7 +8,7 @@ import (
)
// Version version number | 版本号
const Version = "0.1.0"
const Version = "0.1.1"
// Banner startup banner | 启动横幅
const Banner = `
@@ -18,9 +18,18 @@ const Banner = `
___/ / /_/ / / / / /_/ / ,< / __/ / / /_____/ /_/ / /_/ /
/____/\__,_/ /_/ \____/_/|_|\___/_/ /_/ \____/\____/
:: Sa-Token-Go :: (v%s)
:: Sa-Token-Go :: v%s
`
const (
boxWidth = 57
labelWidth = 16
neverExpire = "Never Expire"
noLimit = "No Limit"
configured = "*** (configured)"
secondsFormat = "%d seconds"
)
// Print prints startup banner | 打印启动横幅
func Print() {
fmt.Printf(Banner, Version)
@@ -29,6 +38,36 @@ func Print() {
fmt.Println()
}
// formatConfigLine formats a configuration line with proper padding | 格式化配置行
func formatConfigLine(label string, value any) string {
valueWidth := boxWidth - labelWidth - 5 // 57 - 16 - 5 = 36
valueStr := fmt.Sprintf("%v", value)
return fmt.Sprintf("│ %-*s: %-*s │\n", labelWidth, label, valueWidth, valueStr)
}
// formatTimeout formats timeout value (seconds or special text) | 格式化超时时间值
func formatTimeout(seconds int64) string {
if seconds > 0 {
// Also show human-readable format for large values
if seconds >= 86400 {
days := seconds / 86400
return fmt.Sprintf("%d seconds (%d days)", seconds, days)
}
return fmt.Sprintf(secondsFormat, seconds)
} else if seconds == 0 {
return neverExpire
}
return noLimit
}
// formatCount formats count value (number or "No Limit") | 格式化数量值
func formatCount(count int) string {
if count > 0 {
return fmt.Sprintf("%d", count)
}
return noLimit
}
// PrintWithConfig prints startup banner with full configuration | 打印启动横幅和完整配置信息
func PrintWithConfig(cfg *config.Config) {
Print()
@@ -38,46 +77,17 @@ func PrintWithConfig(cfg *config.Config) {
fmt.Println("├─────────────────────────────────────────────────────────┤")
// Token configuration | Token 配置
fmt.Printf("Token Name : %-35s │\n", cfg.TokenName)
fmt.Printf("Token Style : %-35s │\n", cfg.TokenStyle)
if cfg.Timeout > 0 {
fmt.Printf("│ Token Timeout : %-25d seconds │\n", cfg.Timeout)
} else {
fmt.Printf("│ Token Timeout : %-35s │\n", "Never Expire")
}
if cfg.ActiveTimeout > 0 {
fmt.Printf("│ Active Timeout : %-25d seconds │\n", cfg.ActiveTimeout)
} else {
fmt.Printf("│ Active Timeout : %-35s │\n", "No Limit")
}
fmt.Print(formatConfigLine("Token Name", cfg.TokenName))
fmt.Print(formatConfigLine("Token Style", cfg.TokenStyle))
fmt.Print(formatConfigLine("Token Timeout", formatTimeout(cfg.Timeout)))
fmt.Print(formatConfigLine("Active Timeout", formatTimeout(cfg.ActiveTimeout)))
// Login configuration | 登录配置
fmt.Printf("│ Auto Renew : %-35v │\n", cfg.AutoRenew)
fmt.Printf("│ Concurrent : %-35v │\n", cfg.IsConcurrent)
fmt.Printf("│ Share Token : %-35v │\n", cfg.IsShare)
if cfg.MaxLoginCount > 0 {
fmt.Printf("│ Max Login Count : %-35d │\n", cfg.MaxLoginCount)
} else {
fmt.Printf("│ Max Login Count : %-35s │\n", "No Limit")
}
// Token read source | Token 读取位置
fmt.Println("├─────────────────────────────────────────────────────────┤")
fmt.Printf("│ Read From Header: %-35v │\n", cfg.IsReadHeader)
fmt.Printf("│ Read From Cookie: %-35v │\n", cfg.IsReadCookie)
fmt.Printf("│ Read From Body : %-35v │\n", cfg.IsReadBody)
// Other settings | 其他设置
fmt.Println("├─────────────────────────────────────────────────────────┤")
if cfg.TokenStyle == config.TokenStyleJWT && cfg.JwtSecretKey != "" {
fmt.Printf("│ JWT Secret : %-35s │\n", "*** (configured)")
}
fmt.Printf("│ Logging : %-35v │\n", cfg.IsLog)
fmt.Print(formatConfigLine("Auto Renew", cfg.AutoRenew))
fmt.Print(formatConfigLine("Concurrent", cfg.IsConcurrent))
fmt.Print(formatConfigLine("Share Token", cfg.IsShare))
fmt.Print(formatConfigLine("Max Login Count", formatCount(cfg.MaxLoginCount)))
fmt.Println("└─────────────────────────────────────────────────────────┘")
fmt.Println()
+371
View File
@@ -0,0 +1,371 @@
package banner
import (
"bytes"
"io"
"os"
"runtime"
"strings"
"testing"
"github.com/click33/sa-token-go/core/config"
)
// captureOutput captures stdout output for testing
func captureOutput(f func()) string {
old := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
f()
w.Close()
os.Stdout = old
var buf bytes.Buffer
io.Copy(&buf, r)
return buf.String()
}
func TestPrint(t *testing.T) {
output := captureOutput(func() {
Print()
})
// Check if output contains expected elements
if !strings.Contains(output, "Sa-Token-Go") {
t.Error("Output should contain 'Sa-Token-Go'")
}
if !strings.Contains(output, Version) {
t.Errorf("Output should contain version %s", Version)
}
if !strings.Contains(output, "Go Version") {
t.Error("Output should contain 'Go Version'")
}
if !strings.Contains(output, runtime.Version()) {
t.Errorf("Output should contain Go version %s", runtime.Version())
}
if !strings.Contains(output, "GOOS/GOARCH") {
t.Error("Output should contain 'GOOS/GOARCH'")
}
expectedOS := runtime.GOOS + "/" + runtime.GOARCH
if !strings.Contains(output, expectedOS) {
t.Errorf("Output should contain OS/ARCH %s", expectedOS)
}
}
func TestFormatTimeout(t *testing.T) {
tests := []struct {
name string
seconds int64
expected string
}{
{
name: "Positive seconds less than a day",
seconds: 3600,
expected: "3600 seconds",
},
{
name: "Exactly one day",
seconds: 86400,
expected: "86400 seconds (1 days)",
},
{
name: "Multiple days",
seconds: 259200, // 3 days
expected: "259200 seconds (3 days)",
},
{
name: "30 days",
seconds: 2592000,
expected: "2592000 seconds (30 days)",
},
{
name: "Zero means never expire",
seconds: 0,
expected: neverExpire,
},
{
name: "Negative means no limit",
seconds: -1,
expected: noLimit,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatTimeout(tt.seconds)
if result != tt.expected {
t.Errorf("formatTimeout(%d) = %s, want %s", tt.seconds, result, tt.expected)
}
})
}
}
func TestFormatCount(t *testing.T) {
tests := []struct {
name string
count int
expected string
}{
{
name: "Positive count",
count: 12,
expected: "12",
},
{
name: "Zero means no limit",
count: 0,
expected: noLimit,
},
{
name: "Negative means no limit",
count: -1,
expected: noLimit,
},
{
name: "Large count",
count: 9999,
expected: "9999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatCount(tt.count)
if result != tt.expected {
t.Errorf("formatCount(%d) = %s, want %s", tt.count, result, tt.expected)
}
})
}
}
func TestFormatConfigLine(t *testing.T) {
tests := []struct {
name string
label string
value any
contains []string
}{
{
name: "String value",
label: "Token Name",
value: "sa-token",
contains: []string{
"Token Name",
"sa-token",
"│",
},
},
{
name: "Boolean value",
label: "Auto Renew",
value: true,
contains: []string{
"Auto Renew",
"true",
},
},
{
name: "Integer value",
label: "Max Count",
value: 12,
contains: []string{
"Max Count",
"12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatConfigLine(tt.label, tt.value)
for _, s := range tt.contains {
if !strings.Contains(result, s) {
t.Errorf("formatConfigLine(%s, %v) should contain %s, got: %s", tt.label, tt.value, s, result)
}
}
})
}
}
func TestPrintWithConfig(t *testing.T) {
tests := []struct {
name string
config *config.Config
contains []string
}{
{
name: "Default configuration",
config: config.DefaultConfig(),
contains: []string{
"Configuration",
"Token Name",
"sa-token",
"Token Style",
"uuid",
"Token Timeout",
"30 days",
"Auto Renew",
"Concurrent",
"Share Token",
"Max Login Count",
"Read From Header",
"Read From Cookie",
"Read From Body",
"Logging",
},
},
{
name: "JWT configuration",
config: &config.Config{
TokenName: "jwt-token",
Timeout: 3600,
ActiveTimeout: -1,
IsConcurrent: true,
IsShare: false,
MaxLoginCount: 5,
IsReadBody: false,
IsReadHeader: true,
IsReadCookie: false,
TokenStyle: config.TokenStyleJWT,
AutoRenew: true,
JwtSecretKey: "my-secret-key",
IsLog: true,
CookieConfig: &config.CookieConfig{
Path: "/api",
SameSite: config.SameSiteLax,
HttpOnly: true,
Secure: true,
},
},
contains: []string{
"jwt-token",
"jwt",
"3600 seconds",
"JWT Secret",
"*** (configured)",
"Cookie Path",
"/api",
"Cookie SameSite",
"Cookie HttpOnly",
"Cookie Secure",
},
},
{
name: "Never expire configuration",
config: &config.Config{
TokenName: "never-token",
Timeout: 0,
ActiveTimeout: -1,
IsConcurrent: false,
IsShare: true,
MaxLoginCount: -1,
TokenStyle: config.TokenStyleUUID,
CookieConfig: &config.CookieConfig{},
},
contains: []string{
"Never Expire",
"No Limit",
},
},
{
name: "JWT without secret key",
config: &config.Config{
TokenName: "jwt-token",
TokenStyle: config.TokenStyleJWT,
JwtSecretKey: "",
CookieConfig: &config.CookieConfig{},
},
contains: []string{
"JWT Secret",
"Not Set",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output := captureOutput(func() {
PrintWithConfig(tt.config)
})
for _, s := range tt.contains {
if !strings.Contains(output, s) {
t.Errorf("PrintWithConfig() output should contain '%s'\nGot output:\n%s", s, output)
}
}
// Check for box drawing characters
if !strings.Contains(output, "┌") || !strings.Contains(output, "└") {
t.Error("Output should contain box drawing characters")
}
})
}
}
func TestPrintWithConfigNilCookie(t *testing.T) {
cfg := &config.Config{
TokenName: "test-token",
TokenStyle: config.TokenStyleSimple,
CookieConfig: nil, // nil cookie config
}
output := captureOutput(func() {
PrintWithConfig(cfg)
})
// Should not panic and should not contain cookie configuration
if strings.Contains(output, "Cookie Path") {
t.Error("Output should not contain Cookie configuration when CookieConfig is nil")
}
}
func BenchmarkPrint(b *testing.B) {
// Redirect output to discard
old := os.Stdout
os.Stdout, _ = os.Open(os.DevNull)
defer func() { os.Stdout = old }()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Print()
}
}
func BenchmarkPrintWithConfig(b *testing.B) {
cfg := config.DefaultConfig()
// Redirect output to discard
old := os.Stdout
os.Stdout, _ = os.Open(os.DevNull)
defer func() { os.Stdout = old }()
b.ResetTimer()
for i := 0; i < b.N; i++ {
PrintWithConfig(cfg)
}
}
func BenchmarkFormatTimeout(b *testing.B) {
for i := 0; i < b.N; i++ {
formatTimeout(2592000)
}
}
func BenchmarkFormatCount(b *testing.B) {
for i := 0; i < b.N; i++ {
formatCount(12)
}
}
func BenchmarkFormatConfigLine(b *testing.B) {
for i := 0; i < b.N; i++ {
formatConfigLine("Token Name", "sa-token")
}
}
+165 -41
View File
@@ -1,6 +1,7 @@
package builder
import (
"fmt"
"time"
"github.com/click33/sa-token-go/core/adapter"
@@ -11,39 +12,52 @@ import (
// Builder Sa-Token builder for fluent configuration | Sa-Token构建器,用于流式配置
type Builder struct {
storage adapter.Storage
tokenName string
timeout int64
activeTimeout int64
isConcurrent bool
isShare bool
maxLoginCount int
tokenStyle config.TokenStyle
autoRenew bool
jwtSecretKey string
isLog bool
isPrintBanner bool
isReadBody bool
isReadHeader bool
isReadCookie bool
storage adapter.Storage
tokenName string
timeout int64
activeTimeout int64
isConcurrent bool
isShare bool
maxLoginCount int
tokenStyle config.TokenStyle
autoRenew bool
jwtSecretKey string
isLog bool
isPrintBanner bool
isReadBody bool
isReadHeader bool
isReadCookie bool
dataRefreshPeriod int64
tokenSessionCheckLogin bool
cookieConfig *config.CookieConfig
}
// NewBuilder creates a new builder | 创建新的构建器
// NewBuilder creates a new builder with default configuration | 创建新的构建器(使用默认配置)
func NewBuilder() *Builder {
return &Builder{
tokenName: "satoken",
timeout: 2592000, // 30 days | 30天
activeTimeout: -1,
isConcurrent: true,
isShare: true,
maxLoginCount: 12,
tokenStyle: config.TokenStyleUUID,
autoRenew: true,
isLog: false,
isPrintBanner: true, // Print banner by default | 默认打印 Banner
isReadBody: false, // Don't read from body by default | 默认不从 Body 读取
isReadHeader: true, // Read from header by default | 默认从 Header 读取
isReadCookie: false, // Don't read from cookie by default | 默认不从 Cookie 读取
tokenName: config.DefaultTokenName,
timeout: config.DefaultTimeout,
activeTimeout: config.NoLimit,
isConcurrent: true,
isShare: true,
maxLoginCount: config.DefaultMaxLoginCount,
tokenStyle: config.TokenStyleUUID,
autoRenew: true,
isLog: false,
isPrintBanner: true,
isReadBody: false,
isReadHeader: true,
isReadCookie: false,
dataRefreshPeriod: config.NoLimit,
tokenSessionCheckLogin: true,
cookieConfig: &config.CookieConfig{
Domain: "",
Path: config.DefaultCookiePath,
Secure: false,
HttpOnly: true,
SameSite: config.SameSiteLax,
MaxAge: 0,
},
}
}
@@ -143,10 +157,122 @@ func (b *Builder) IsReadCookie(isRead bool) *Builder {
return b
}
// DataRefreshPeriod sets data refresh period | 设置数据刷新周期
func (b *Builder) DataRefreshPeriod(seconds int64) *Builder {
b.dataRefreshPeriod = seconds
return b
}
// TokenSessionCheckLogin sets whether to check token session on login | 设置登录时是否检查Token会话
func (b *Builder) TokenSessionCheckLogin(check bool) *Builder {
b.tokenSessionCheckLogin = check
return b
}
// CookieDomain sets cookie domain | 设置Cookie域名
func (b *Builder) CookieDomain(domain string) *Builder {
if b.cookieConfig == nil {
b.cookieConfig = &config.CookieConfig{}
}
b.cookieConfig.Domain = domain
return b
}
// CookiePath sets cookie path | 设置Cookie路径
func (b *Builder) CookiePath(path string) *Builder {
if b.cookieConfig == nil {
b.cookieConfig = &config.CookieConfig{}
}
b.cookieConfig.Path = path
return b
}
// CookieSecure sets cookie secure flag | 设置Cookie的Secure标志
func (b *Builder) CookieSecure(secure bool) *Builder {
if b.cookieConfig == nil {
b.cookieConfig = &config.CookieConfig{}
}
b.cookieConfig.Secure = secure
return b
}
// CookieHttpOnly sets cookie httpOnly flag | 设置Cookie的HttpOnly标志
func (b *Builder) CookieHttpOnly(httpOnly bool) *Builder {
if b.cookieConfig == nil {
b.cookieConfig = &config.CookieConfig{}
}
b.cookieConfig.HttpOnly = httpOnly
return b
}
// CookieSameSite sets cookie sameSite attribute | 设置Cookie的SameSite属性
func (b *Builder) CookieSameSite(sameSite config.SameSiteMode) *Builder {
if b.cookieConfig == nil {
b.cookieConfig = &config.CookieConfig{}
}
b.cookieConfig.SameSite = sameSite
return b
}
// CookieMaxAge sets cookie max age | 设置Cookie的最大年龄
func (b *Builder) CookieMaxAge(maxAge int) *Builder {
if b.cookieConfig == nil {
b.cookieConfig = &config.CookieConfig{}
}
b.cookieConfig.MaxAge = maxAge
return b
}
// CookieConfig sets complete cookie configuration | 设置完整的Cookie配置
func (b *Builder) CookieConfig(cfg *config.CookieConfig) *Builder {
b.cookieConfig = cfg
return b
}
// NeverExpire sets token to never expire | 设置Token永不过期
func (b *Builder) NeverExpire() *Builder {
b.timeout = config.NoLimit
return b
}
// NoActiveTimeout disables active timeout | 禁用活跃超时
func (b *Builder) NoActiveTimeout() *Builder {
b.activeTimeout = config.NoLimit
return b
}
// UnlimitedLogin allows unlimited concurrent logins | 允许无限并发登录
func (b *Builder) UnlimitedLogin() *Builder {
b.maxLoginCount = config.NoLimit
return b
}
// Validate validates the builder configuration | 验证构建器配置
func (b *Builder) Validate() error {
if b.storage == nil {
return fmt.Errorf("storage is required, please call Storage() method")
}
if b.tokenName == "" {
return fmt.Errorf("tokenName cannot be empty")
}
if b.tokenStyle == config.TokenStyleJWT && b.jwtSecretKey == "" {
return fmt.Errorf("jwtSecretKey is required when TokenStyle is JWT")
}
if !b.isReadHeader && !b.isReadCookie && !b.isReadBody {
return fmt.Errorf("at least one of IsReadHeader, IsReadCookie, or IsReadBody must be true")
}
return nil
}
// Build builds Manager and prints startup banner | 构建Manager并打印启动Banner
func (b *Builder) Build() *manager.Manager {
if b.storage == nil {
panic("storage is required, please call Storage() method")
// Validate configuration | 验证配置
if err := b.Validate(); err != nil {
panic(fmt.Sprintf("invalid configuration: %v", err))
}
cfg := &config.Config{
@@ -160,20 +286,13 @@ func (b *Builder) Build() *manager.Manager {
IsReadHeader: b.isReadHeader,
IsReadCookie: b.isReadCookie,
TokenStyle: b.tokenStyle,
DataRefreshPeriod: -1,
TokenSessionCheckLogin: true,
DataRefreshPeriod: b.dataRefreshPeriod,
TokenSessionCheckLogin: b.tokenSessionCheckLogin,
AutoRenew: b.autoRenew,
JwtSecretKey: b.jwtSecretKey,
IsLog: b.isLog,
IsPrintBanner: b.isPrintBanner,
CookieConfig: &config.CookieConfig{
Domain: "",
Path: "/",
Secure: false,
HttpOnly: true,
SameSite: "Lax",
MaxAge: 0,
},
CookieConfig: b.cookieConfig,
}
// Print startup banner with full configuration | 打印启动Banner和完整配置
@@ -189,3 +308,8 @@ func (b *Builder) Build() *manager.Manager {
return mgr
}
// MustBuild builds Manager and panics if validation fails | 构建Manager,验证失败时panic
func (b *Builder) MustBuild() *manager.Manager {
return b.Build()
}
+141 -18
View File
@@ -1,5 +1,7 @@
package config
import "fmt"
// TokenStyle Token generation style | Token生成风格
type TokenStyle string
@@ -24,6 +26,39 @@ const (
TokenStyleTik TokenStyle = "tik"
)
// SameSiteMode Cookie SameSite attribute values | Cookie的SameSite属性值
type SameSiteMode string
const (
// SameSiteStrict Strict mode | 严格模式
SameSiteStrict SameSiteMode = "Strict"
// SameSiteLax Lax mode | 宽松模式
SameSiteLax SameSiteMode = "Lax"
// SameSiteNone None mode | 无限制模式
SameSiteNone SameSiteMode = "None"
)
// Default configuration constants | 默认配置常量
const (
DefaultTokenName = "satoken"
DefaultTimeout = 2592000 // 30 days in seconds | 30天(秒)
DefaultMaxLoginCount = 12 // Maximum concurrent logins | 最大并发登录数
DefaultCookiePath = "/"
NoLimit = -1 // No limit flag | 不限制标志
)
// IsValid checks if the TokenStyle is valid | 检查TokenStyle是否有效
func (ts TokenStyle) IsValid() bool {
switch ts {
case TokenStyleUUID, TokenStyleSimple, TokenStyleRandom32,
TokenStyleRandom64, TokenStyleRandom128, TokenStyleJWT,
TokenStyleHash, TokenStyleTimestamp, TokenStyleTik:
return true
default:
return false
}
}
// Config Sa-Token configuration | Sa-Token配置
type Config struct {
// TokenName Token name (also used as Cookie name) | Token名称(同时也是Cookie名称)
@@ -93,7 +128,7 @@ type CookieConfig struct {
HttpOnly bool
// SameSite SameSite attribute (Strict, Lax, None) | SameSite属性(Strict、Lax、None
SameSite string
SameSite SameSiteMode
// MaxAge Cookie expiration time in seconds | 过期时间(单位:秒)
MaxAge int
@@ -102,33 +137,73 @@ type CookieConfig struct {
// DefaultConfig Returns default configuration | 返回默认配置
func DefaultConfig() *Config {
return &Config{
TokenName: "sa-token",
Timeout: 2592000, // 30 days | 30天
ActiveTimeout: -1, // No limit | 不限制
IsConcurrent: true, // Allow concurrent login | 允许并发登录
IsShare: true, // Share Token | 共享Token
MaxLoginCount: 12, // Max 12 logins | 最多12个
IsReadBody: false, // Don't read from Body (default) | 不从Body读取(默认)
IsReadHeader: true, // Read from Header (recommended) | 从Header读取(推荐)
IsReadCookie: false, // Don't read from Cookie (default) | 不从Cookie读取(默认)
TokenName: DefaultTokenName,
Timeout: DefaultTimeout,
ActiveTimeout: NoLimit,
IsConcurrent: true,
IsShare: true,
MaxLoginCount: DefaultMaxLoginCount,
IsReadBody: false,
IsReadHeader: true,
IsReadCookie: false,
TokenStyle: TokenStyleUUID,
DataRefreshPeriod: -1, // No auto-refresh | 不自动续签
TokenSessionCheckLogin: true, // Check on login | 登录时检查
AutoRenew: true, // Auto-renew | 自动续期
JwtSecretKey: "", // Empty by default | 默认空
IsLog: false, // No logging | 不输出日志
IsPrintBanner: true, // Print startup banner | 打印启动 Banner
DataRefreshPeriod: NoLimit,
TokenSessionCheckLogin: true,
AutoRenew: true,
JwtSecretKey: "",
IsLog: false,
IsPrintBanner: true,
CookieConfig: &CookieConfig{
Domain: "",
Path: "/",
Path: DefaultCookiePath,
Secure: false,
HttpOnly: true,
SameSite: "Lax",
SameSite: SameSiteLax,
MaxAge: 0,
},
}
}
// Validate validates the configuration | 验证配置是否合理
func (c *Config) Validate() error {
// Check TokenName
if c.TokenName == "" {
return fmt.Errorf("TokenName cannot be empty")
}
// Check TokenStyle
if !c.TokenStyle.IsValid() {
return fmt.Errorf("invalid TokenStyle: %s", c.TokenStyle)
}
// Check JWT secret key when using JWT style
if c.TokenStyle == TokenStyleJWT && c.JwtSecretKey == "" {
return fmt.Errorf("JwtSecretKey is required when TokenStyle is JWT")
}
// Check Timeout
if c.Timeout < NoLimit {
return fmt.Errorf("Timeout must be >= -1, got: %d", c.Timeout)
}
// Check ActiveTimeout
if c.ActiveTimeout < NoLimit {
return fmt.Errorf("ActiveTimeout must be >= -1, got: %d", c.ActiveTimeout)
}
// Check MaxLoginCount
if c.MaxLoginCount < NoLimit {
return fmt.Errorf("MaxLoginCount must be >= -1, got: %d", c.MaxLoginCount)
}
// Check if at least one read source is enabled
if !c.IsReadHeader && !c.IsReadCookie && !c.IsReadBody {
return fmt.Errorf("at least one of IsReadHeader, IsReadCookie, or IsReadBody must be true")
}
return nil
}
// Clone Clone configuration | 克隆配置
func (c *Config) Clone() *Config {
newConfig := *c
@@ -169,12 +244,48 @@ func (c *Config) SetIsShare(isShare bool) *Config {
return c
}
// SetMaxLoginCount Set maximum login count | 设置最大登录数量
func (c *Config) SetMaxLoginCount(count int) *Config {
c.MaxLoginCount = count
return c
}
// SetIsReadBody Set whether to read Token from body | 设置是否从请求体读取Token
func (c *Config) SetIsReadBody(isReadBody bool) *Config {
c.IsReadBody = isReadBody
return c
}
// SetIsReadHeader Set whether to read Token from header | 设置是否从Header读取Token
func (c *Config) SetIsReadHeader(isReadHeader bool) *Config {
c.IsReadHeader = isReadHeader
return c
}
// SetIsReadCookie Set whether to read Token from cookie | 设置是否从Cookie读取Token
func (c *Config) SetIsReadCookie(isReadCookie bool) *Config {
c.IsReadCookie = isReadCookie
return c
}
// SetTokenStyle Set Token generation style | 设置Token风格
func (c *Config) SetTokenStyle(style TokenStyle) *Config {
c.TokenStyle = style
return c
}
// SetDataRefreshPeriod Set data refresh period | 设置数据刷新周期
func (c *Config) SetDataRefreshPeriod(period int64) *Config {
c.DataRefreshPeriod = period
return c
}
// SetTokenSessionCheckLogin Set whether to check token session on login | 设置登录时是否检查token会话
func (c *Config) SetTokenSessionCheckLogin(check bool) *Config {
c.TokenSessionCheckLogin = check
return c
}
// SetJwtSecretKey Set JWT secret key | 设置JWT密钥
func (c *Config) SetJwtSecretKey(key string) *Config {
c.JwtSecretKey = key
@@ -192,3 +303,15 @@ func (c *Config) SetIsLog(isLog bool) *Config {
c.IsLog = isLog
return c
}
// SetIsPrintBanner Set whether to print banner | 设置是否打印Banner
func (c *Config) SetIsPrintBanner(isPrint bool) *Config {
c.IsPrintBanner = isPrint
return c
}
// SetCookieConfig Set cookie configuration | 设置Cookie配置
func (c *Config) SetCookieConfig(cookieConfig *CookieConfig) *Config {
c.CookieConfig = cookieConfig
return c
}
+31 -13
View File
@@ -1,10 +1,17 @@
package context
import (
"strings"
"github.com/click33/sa-token-go/core/adapter"
"github.com/click33/sa-token-go/core/manager"
)
const (
bearerPrefix = "Bearer "
authHeader = "Authorization"
)
// SaTokenContext Sa-Token context for current request | Sa-Token上下文,用于当前请求
type SaTokenContext struct {
ctx adapter.RequestContext
@@ -19,38 +26,49 @@ func NewContext(ctx adapter.RequestContext, mgr *manager.Manager) *SaTokenContex
}
}
// extractBearerToken 从 Authorization 头中提取 Bearer Token
func extractBearerToken(auth string) string {
auth = strings.TrimSpace(auth)
if auth == "" {
return ""
}
// 支持大小写不敏感的 Bearer 前缀
if len(auth) > 7 && strings.EqualFold(auth[:7], bearerPrefix) {
return strings.TrimSpace(auth[7:])
}
return auth
}
// GetTokenValue gets token value from current request | 获取当前请求的Token值
func (c *SaTokenContext) GetTokenValue() string {
cfg := c.manager.GetConfig()
// 1. 尝试从Header获取
if cfg.IsReadHeader {
token := c.ctx.GetHeader(cfg.TokenName)
if token != "" {
// 从自定义 token 名称的 Header 获取
if token := strings.TrimSpace(c.ctx.GetHeader(cfg.TokenName)); token != "" {
return token
}
// 也尝试从Authorization头获取
auth := c.ctx.GetHeader("Authorization")
if auth != "" {
// 移除 "Bearer " 前缀
if len(auth) > 7 && auth[:7] == "Bearer " {
return auth[7:]
// 从 Authorization 头获取
if auth := c.ctx.GetHeader(authHeader); auth != "" {
if token := extractBearerToken(auth); token != "" {
return token
}
return auth
}
}
// 2. 尝试从Cookie获取
if cfg.IsReadCookie {
token := c.ctx.GetCookie(cfg.TokenName)
if token != "" {
if token := strings.TrimSpace(c.ctx.GetCookie(cfg.TokenName)); token != "" {
return token
}
}
// 3. 尝试从Query参数获取
token := c.ctx.GetQuery(cfg.TokenName)
if token != "" {
if token := strings.TrimSpace(c.ctx.GetQuery(cfg.TokenName)); token != "" {
return token
}
+127 -29
View File
@@ -1,10 +1,15 @@
package core
import "fmt"
import (
"errors"
"fmt"
)
// Common error definitions for better error handling and internationalization support
// 常见错误定义,用于更好的错误处理和国际化支持
// ============ Authentication Errors | 认证错误 ============
var (
// ErrNotLogin indicates the user is not logged in | 用户未登录错误
ErrNotLogin = fmt.Errorf("authentication required: user not logged in")
@@ -15,20 +20,38 @@ var (
// ErrTokenExpired indicates the token has expired | Token已过期
ErrTokenExpired = fmt.Errorf("token expired: please login again to get a new token")
// ErrAccountDisabled indicates the account has been disabled or banned | 账号已被禁用
ErrAccountDisabled = fmt.Errorf("account disabled: this account has been temporarily or permanently disabled")
// ErrInvalidLoginID indicates the login ID is invalid | 登录ID无效
ErrInvalidLoginID = fmt.Errorf("invalid login ID: the login identifier cannot be empty")
// ErrInvalidDevice indicates the device identifier is invalid | 设备标识无效
ErrInvalidDevice = fmt.Errorf("invalid device: the device identifier is not valid")
)
// ============ Authorization Errors | 授权错误 ============
var (
// ErrPermissionDenied indicates insufficient permissions | 权限不足
ErrPermissionDenied = fmt.Errorf("permission denied: you don't have the required permission")
// ErrRoleDenied indicates insufficient role | 角色权限不足
ErrRoleDenied = fmt.Errorf("role denied: you don't have the required role")
)
// ErrSessionNotFound indicates the session doesn't exist | Session不存在
ErrSessionNotFound = fmt.Errorf("session not found: the session may have expired or been deleted")
// ============ Account Errors | 账号错误 ============
var (
// ErrAccountDisabled indicates the account has been disabled or banned | 账号已被禁用
ErrAccountDisabled = fmt.Errorf("account disabled: this account has been temporarily or permanently disabled")
// ErrAccountNotFound indicates the account doesn't exist | 账号不存在
ErrAccountNotFound = fmt.Errorf("account not found: no account associated with this identifier")
)
// ============ Session Errors | 会话错误 ============
var (
// ErrSessionNotFound indicates the session doesn't exist | Session不存在
ErrSessionNotFound = fmt.Errorf("session not found: the session may have expired or been deleted")
// ErrKickedOut indicates the user has been kicked out | 用户已被踢下线
ErrKickedOut = fmt.Errorf("kicked out: this session has been forcibly terminated")
@@ -38,26 +61,26 @@ var (
// ErrMaxLoginCount indicates maximum concurrent login limit reached | 达到最大登录数量限制
ErrMaxLoginCount = fmt.Errorf("max login limit: maximum number of concurrent logins reached")
// ErrStorageUnavailable indicates the storage backend is unavailable | 存储后端不可用
ErrStorageUnavailable = fmt.Errorf("storage unavailable: unable to connect to storage backend")
// ErrInvalidLoginID indicates the login ID is invalid | 登录ID无效
ErrInvalidLoginID = fmt.Errorf("invalid login ID: the login identifier cannot be empty")
// ErrInvalidDevice indicates the device identifier is invalid | 设备标识无效
ErrInvalidDevice = fmt.Errorf("invalid device: the device identifier is not valid")
)
// SaTokenError represents a custom error with error code and context | 自定义错误类型,包含错误码和上下文信息
// ============ System Errors | 系统错误 ============
var (
// ErrStorageUnavailable indicates the storage backend is unavailable | 存储后端不可用
ErrStorageUnavailable = fmt.Errorf("storage unavailable: unable to connect to storage backend")
)
// ============ Custom Error Type | 自定义错误类型 ============
// SaTokenError Represents a custom error with error code and context | 自定义错误类型,包含错误码和上下文信息
type SaTokenError struct {
Code int // Error code for programmatic handling | 错误码,用于程序化处理
Message string // Human-readable error message | 可读的错误消息
Err error // Underlying error (if any) | 底层错误(如果有)
Context map[string]interface{} // Additional context information | 额外的上下文信息
Code int // Error code for programmatic handling | 错误码,用于程序化处理
Message string // Human-readable error message | 可读的错误消息
Err error // Underlying error (if any) | 底层错误(如果有)
Context map[string]any // Additional context information | 额外的上下文信息
}
// Error implements the error interface | 实现 error 接口
// Error Implements the error interface | 实现 error 接口
func (e *SaTokenError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s (code: %d): %v", e.Message, e.Code, e.Err)
@@ -65,32 +88,52 @@ func (e *SaTokenError) Error() string {
return fmt.Sprintf("%s (code: %d)", e.Message, e.Code)
}
// Unwrap implements the unwrap interface for error chains | 实现 unwrap 接口,支持错误链
// Unwrap Implements the unwrap interface for error chains | 实现 unwrap 接口,支持错误链
func (e *SaTokenError) Unwrap() error {
return e.Err
}
// WithContext adds context information to the error | 为错误添加上下文信息
func (e *SaTokenError) WithContext(key string, value interface{}) *SaTokenError {
// WithContext Adds context information to the error | 为错误添加上下文信息
func (e *SaTokenError) WithContext(key string, value any) *SaTokenError {
if e.Context == nil {
e.Context = make(map[string]interface{})
e.Context = make(map[string]any)
}
e.Context[key] = value
return e
}
// NewError creates a new Sa-Token error | 创建新的 Sa-Token 错误
// GetContext Gets context value | 获取上下文值
func (e *SaTokenError) GetContext(key string) (any, bool) {
if e.Context == nil {
return nil, false
}
val, exists := e.Context[key]
return val, exists
}
// Is Implements errors.Is for error comparison | 实现 errors.Is 进行错误比较
func (e *SaTokenError) Is(target error) bool {
t, ok := target.(*SaTokenError)
if !ok {
return false
}
return e.Code == t.Code
}
// ============ Error Constructors | 错误构造函数 ============
// NewError Creates a new Sa-Token error | 创建新的 Sa-Token 错误
func NewError(code int, message string, err error) *SaTokenError {
return &SaTokenError{
Code: code,
Message: message,
Err: err,
Context: make(map[string]interface{}),
Context: make(map[string]any),
}
}
// NewErrorWithContext creates a new Sa-Token error with context | 创建带上下文的 Sa-Token 错误
func NewErrorWithContext(code int, message string, err error, context map[string]interface{}) *SaTokenError {
// NewErrorWithContext Creates a new Sa-Token error with context | 创建带上下文的 Sa-Token 错误
func NewErrorWithContext(code int, message string, err error, context map[string]any) *SaTokenError {
return &SaTokenError{
Code: code,
Message: message,
@@ -99,7 +142,62 @@ func NewErrorWithContext(code int, message string, err error, context map[string
}
}
// Error code definitions | 错误码定义
// NewNotLoginError Creates a not login error | 创建未登录错误
func NewNotLoginError() *SaTokenError {
return NewError(CodeNotLogin, "user not logged in", ErrNotLogin)
}
// NewPermissionDeniedError Creates a permission denied error | 创建权限拒绝错误
func NewPermissionDeniedError(permission string) *SaTokenError {
return NewError(CodePermissionDenied, "permission denied", ErrPermissionDenied).
WithContext("permission", permission)
}
// NewRoleDeniedError Creates a role denied error | 创建角色拒绝错误
func NewRoleDeniedError(role string) *SaTokenError {
return NewError(CodePermissionDenied, "role denied", ErrRoleDenied).
WithContext("role", role)
}
// NewAccountDisabledError Creates an account disabled error | 创建账号禁用错误
func NewAccountDisabledError(loginID string) *SaTokenError {
return NewError(CodeAccountDisabled, "account disabled", ErrAccountDisabled).
WithContext("loginID", loginID)
}
// ============ Error Checking Helpers | 错误检查辅助函数 ============
// IsNotLoginError Checks if error is a not login error | 检查是否为未登录错误
func IsNotLoginError(err error) bool {
return errors.Is(err, ErrNotLogin)
}
// IsPermissionDeniedError Checks if error is a permission denied error | 检查是否为权限拒绝错误
func IsPermissionDeniedError(err error) bool {
return errors.Is(err, ErrPermissionDenied)
}
// IsAccountDisabledError Checks if error is an account disabled error | 检查是否为账号禁用错误
func IsAccountDisabledError(err error) bool {
return errors.Is(err, ErrAccountDisabled)
}
// IsTokenError Checks if error is a token-related error | 检查是否为Token相关错误
func IsTokenError(err error) bool {
return errors.Is(err, ErrTokenInvalid) || errors.Is(err, ErrTokenExpired)
}
// GetErrorCode Extracts error code from SaTokenError | 从SaTokenError中提取错误码
func GetErrorCode(err error) int {
var saErr *SaTokenError
if errors.As(err, &saErr) {
return saErr.Code
}
return CodeServerError
}
// ============ Error Code Definitions | 错误码定义 ============
const (
// Standard HTTP status codes | 标准 HTTP 状态码
CodeSuccess = 200 // Request successful | 请求成功
+146 -15
View File
@@ -46,12 +46,12 @@ const (
// EventData contains information about a triggered event | 事件数据,包含触发事件的相关信息
type EventData struct {
Event Event // Event type | 事件类型
LoginID string // User login ID | 用户登录ID
Device string // Device identifier | 设备标识
Token string // Authentication token | 认证Token
Extra map[string]interface{} // Additional custom data | 额外的自定义数据
Timestamp int64 // Unix timestamp when event was triggered | 事件触发的Unix时间戳
Event Event // Event type | 事件类型
LoginID string // User login ID | 用户登录ID
Device string // Device identifier | 设备标识
Token string // Authentication token | 认证Token
Extra map[string]any // Additional custom data | 额外的自定义数据
Timestamp int64 // Unix timestamp when event was triggered | 事件触发的Unix时间戳
}
// String returns a string representation of the event data | 返回事件数据的字符串表示
@@ -87,35 +87,106 @@ type listenerEntry struct {
config ListenerConfig
}
// EventFilter is a function that decides whether an event should be processed | 事件过滤器,决定事件是否应该被处理
type EventFilter func(data *EventData) bool
// EventStats contains statistics about event processing | 事件统计信息
type EventStats struct {
TotalTriggered int64 // Total number of events triggered | 触发的事件总数
EventCounts map[Event]int64 // Count per event type | 各类型事件的计数
LastTriggered map[Event]time.Time // Last trigger time per event | 各类型事件的最后触发时间
}
// Manager manages event listeners and dispatches events | 事件管理器,管理监听器并分发事件
type Manager struct {
mu sync.RWMutex
listeners map[Event][]listenerEntry
panicHandler func(event Event, data *EventData, recovered interface{})
panicHandler func(event Event, data *EventData, recovered any)
listenerCounter int
enabledEvents map[Event]bool // If nil, all events are enabled | 如果为nil,所有事件都启用
asyncWaitGroup sync.WaitGroup // For waiting on async listeners during shutdown | 用于等待异步监听器完成
filters []EventFilter // Global event filters | 全局事件过滤器
stats *EventStats // Event statistics | 事件统计
enableStats bool // Whether to collect statistics | 是否收集统计信息
}
// NewManager creates a new event manager | 创建新的事件管理器
func NewManager() *Manager {
return &Manager{
listeners: make(map[Event][]listenerEntry),
panicHandler: func(event Event, data *EventData, recovered interface{}) {
panicHandler: func(event Event, data *EventData, recovered any) {
// Default panic handler: log but don't crash | 默认panic处理器:记录日志但不崩溃
fmt.Printf("sa-token: listener panic recovered: event=%s, panic=%v\n", event, recovered)
},
enabledEvents: nil, // All events enabled by default | 默认启用所有事件
filters: make([]EventFilter, 0),
stats: &EventStats{
EventCounts: make(map[Event]int64),
LastTriggered: make(map[Event]time.Time),
},
enableStats: false, // Stats disabled by default | 默认不启用统计
}
}
// SetPanicHandler sets a custom panic handler for listener errors | 设置自定义的panic处理器
func (m *Manager) SetPanicHandler(handler func(event Event, data *EventData, recovered interface{})) {
func (m *Manager) SetPanicHandler(handler func(event Event, data *EventData, recovered any)) {
m.mu.Lock()
defer m.mu.Unlock()
m.panicHandler = handler
}
// AddFilter adds a global event filter | 添加全局事件过滤器
func (m *Manager) AddFilter(filter EventFilter) {
m.mu.Lock()
defer m.mu.Unlock()
m.filters = append(m.filters, filter)
}
// ClearFilters removes all event filters | 清除所有事件过滤器
func (m *Manager) ClearFilters() {
m.mu.Lock()
defer m.mu.Unlock()
m.filters = make([]EventFilter, 0)
}
// EnableStats enables event statistics collection | 启用事件统计
func (m *Manager) EnableStats(enable bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.enableStats = enable
}
// GetStats returns a copy of event statistics | 获取事件统计信息副本
func (m *Manager) GetStats() EventStats {
m.mu.RLock()
defer m.mu.RUnlock()
stats := EventStats{
TotalTriggered: m.stats.TotalTriggered,
EventCounts: make(map[Event]int64),
LastTriggered: make(map[Event]time.Time),
}
for event, count := range m.stats.EventCounts {
stats.EventCounts[event] = count
}
for event, t := range m.stats.LastTriggered {
stats.LastTriggered[event] = t
}
return stats
}
// ResetStats resets event statistics | 重置事件统计
func (m *Manager) ResetStats() {
m.mu.Lock()
defer m.mu.Unlock()
m.stats = &EventStats{
EventCounts: make(map[Event]int64),
LastTriggered: make(map[Event]time.Time),
}
}
// EnableEvent enables specific events (disables all others) | 启用特定事件(禁用其他所有事件)
// Call with no arguments to enable all events | 不传参数时启用所有事件
func (m *Manager) EnableEvent(events ...Event) {
@@ -230,13 +301,15 @@ func (m *Manager) Unregister(listenerID string) bool {
// sortListeners sorts listeners by priority (descending)
func (m *Manager) sortListeners(event Event) {
entries := m.listeners[event]
// Simple bubble sort (listeners count is usually small)
for i := 0; i < len(entries)-1; i++ {
for j := 0; j < len(entries)-i-1; j++ {
if entries[j].config.Priority < entries[j+1].config.Priority {
entries[j], entries[j+1] = entries[j+1], entries[j]
}
// Use insertion sort (efficient for small lists and maintains stability)
for i := 1; i < len(entries); i++ {
key := entries[i]
j := i - 1
for j >= 0 && entries[j].config.Priority < key.config.Priority {
entries[j+1] = entries[j]
j--
}
entries[j+1] = key
}
}
@@ -255,6 +328,21 @@ func (m *Manager) Trigger(data *EventData) {
data.Timestamp = time.Now().Unix()
}
// Apply filters
for _, filter := range m.filters {
if !filter(data) {
m.mu.RUnlock()
return // Event filtered out
}
}
// Update statistics
if m.enableStats {
m.stats.TotalTriggered++
m.stats.EventCounts[data.Event]++
m.stats.LastTriggered[data.Event] = time.Now()
}
// Collect listeners to call
var listenersToCall []listenerEntry
@@ -281,6 +369,17 @@ func (m *Manager) Trigger(data *EventData) {
}
}
// TriggerAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回
func (m *Manager) TriggerAsync(data *EventData) {
go m.Trigger(data)
}
// TriggerSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成
func (m *Manager) TriggerSync(data *EventData) {
m.Trigger(data)
m.Wait()
}
// safeCall executes a listener with panic recovery
func (m *Manager) safeCall(listener Listener, data *EventData, wg *sync.WaitGroup) {
if wg != nil {
@@ -339,3 +438,35 @@ func (m *Manager) CountForEvent(event Event) int {
defer m.mu.RUnlock()
return len(m.listeners[event])
}
// GetListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID
func (m *Manager) GetListenerIDs(event Event) []string {
m.mu.RLock()
defer m.mu.RUnlock()
entries := m.listeners[event]
ids := make([]string, 0, len(entries))
for _, entry := range entries {
ids = append(ids, entry.config.ID)
}
return ids
}
// GetAllEvents returns all events that have registered listeners | 获取所有已注册监听器的事件
func (m *Manager) GetAllEvents() []Event {
m.mu.RLock()
defer m.mu.RUnlock()
events := make([]Event, 0, len(m.listeners))
for event := range m.listeners {
events = append(events, event)
}
return events
}
// HasListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器
func (m *Manager) HasListeners(event Event) bool {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.listeners[event]) > 0
}
+226 -154
View File
@@ -14,16 +14,48 @@ import (
"github.com/click33/sa-token-go/core/token"
)
// TokenInfo Token信息
// Constants for storage keys and default values | 存储键和默认值常量
const (
DefaultDevice = "default"
DefaultPrefix = "satoken:"
DisableValue = "1"
DefaultNonceTTL = 5 * time.Minute
// Key prefixes | 键前缀
TokenKeyPrefix = "token:"
AccountKeyPrefix = "account:"
DisableKeyPrefix = "disable:"
// Session keys | Session键
SessionKeyLoginID = "loginId"
SessionKeyDevice = "device"
SessionKeyLoginTime = "loginTime"
SessionKeyPermissions = "permissions"
SessionKeyRoles = "roles"
// Wildcard for permissions | 权限通配符
PermissionWildcard = "*"
PermissionSeparator = ":"
)
// Error variables | 错误变量
var (
ErrAccountDisabled = fmt.Errorf("account is disabled")
ErrNotLogin = fmt.Errorf("not login")
ErrTokenNotFound = fmt.Errorf("token not found")
ErrInvalidTokenData = fmt.Errorf("invalid token data")
)
// TokenInfo Token information | Token信息
type TokenInfo struct {
LoginID string `json:"loginId"`
Device string `json:"device"`
CreateTime int64 `json:"createTime"`
ActiveTime int64 `json:"activeTime"` // 最后活跃时间
ActiveTime int64 `json:"activeTime"` // Last active time | 最后活跃时间
Tag string `json:"tag,omitempty"`
}
// Manager 认证管理器
// Manager Authentication manager | 认证管理器
type Manager struct {
storage adapter.Storage
config *config.Config
@@ -34,7 +66,7 @@ type Manager struct {
oauth2Server *oauth2.OAuth2Server
}
// NewManager 创建管理器
// NewManager Creates a new manager | 创建管理器
func NewManager(storage adapter.Storage, cfg *config.Config) *Manager {
if cfg == nil {
cfg = config.DefaultConfig()
@@ -44,46 +76,63 @@ func NewManager(storage adapter.Storage, cfg *config.Config) *Manager {
storage: storage,
config: cfg,
generator: token.NewGenerator(cfg),
prefix: "satoken:",
nonceManager: security.NewNonceManager(storage, 5*time.Minute),
prefix: DefaultPrefix,
nonceManager: security.NewNonceManager(storage, DefaultNonceTTL),
refreshManager: security.NewRefreshTokenManager(storage, cfg),
oauth2Server: oauth2.NewOAuth2Server(storage),
}
}
// ============ 登录认证 ============
// ============ Helper Methods | 辅助方法 ============
// Login 登录,返回Token
// getDevice extracts device type from optional parameter | 从可选参数中提取设备类型
func getDevice(device []string) string {
if len(device) > 0 && device[0] != "" {
return device[0]
}
return DefaultDevice
}
// getExpiration calculates expiration duration from config | 从配置计算过期时间
func (m *Manager) getExpiration() time.Duration {
if m.config.Timeout > 0 {
return time.Duration(m.config.Timeout) * time.Second
}
return 0
}
// assertString safely converts interface to string | 安全地将interface转换为string
func assertString(v any) (string, bool) {
s, ok := v.(string)
return s, ok
}
// ============ Login Authentication | 登录认证 ============
// Login Performs user login and returns token | 登录,返回Token
func (m *Manager) Login(loginID string, device ...string) (string, error) {
deviceType := "default"
if len(device) > 0 {
deviceType = device[0]
}
deviceType := getDevice(device)
// 检查是否被封禁
// Check if account is disabled | 检查是否被封禁
if m.IsDisable(loginID) {
return "", fmt.Errorf("account is disabled")
return "", ErrAccountDisabled
}
// 如果不允许并发登录,先踢掉旧的
// Kick out old session if concurrent login is not allowed | 如果不允许并发登录,先踢掉旧的
if !m.config.IsConcurrent {
m.kickout(loginID, deviceType)
}
// 生成Token
// Generate token | 生成Token
tokenValue, err := m.generator.Generate(loginID, deviceType)
if err != nil {
return "", fmt.Errorf("failed to generate token: %w", err)
}
// 计算过期时间
var expiration time.Duration
if m.config.Timeout > 0 {
expiration = time.Duration(m.config.Timeout) * time.Second
}
expiration := m.getExpiration()
now := time.Now().Unix()
// 保存Token信息
// Save token info | 保存Token信息
tokenInfo := &TokenInfo{
LoginID: loginID,
Device: deviceType,
@@ -95,34 +144,27 @@ func (m *Manager) Login(loginID string, device ...string) (string, error) {
return "", err
}
// 保存账号-Token映射
// Save account-token mapping | 保存账号-Token映射
accountKey := m.getAccountKey(loginID, deviceType)
if err := m.storage.Set(accountKey, tokenValue, expiration); err != nil {
return "", fmt.Errorf("failed to save account mapping: %w", err)
}
// 创建Session
// Create session | 创建Session
sess := session.NewSession(loginID, m.storage, m.prefix)
sess.Set("loginId", loginID)
sess.Set("device", deviceType)
sess.Set("loginTime", now)
sess.Set(SessionKeyLoginID, loginID)
sess.Set(SessionKeyDevice, deviceType)
sess.Set(SessionKeyLoginTime, now)
return tokenValue, nil
}
// LoginByToken 使用指定Token登录(用于token无感刷新)
// LoginByToken Login with specified token (for seamless token refresh) | 使用指定Token登录(用于token无感刷新)
func (m *Manager) LoginByToken(loginID string, tokenValue string, device ...string) error {
deviceType := "default"
if len(device) > 0 {
deviceType = device[0]
}
var expiration time.Duration
if m.config.Timeout > 0 {
expiration = time.Duration(m.config.Timeout) * time.Second
}
deviceType := getDevice(device)
expiration := m.getExpiration()
now := time.Now().Unix()
tokenInfo := &TokenInfo{
LoginID: loginID,
Device: deviceType,
@@ -138,37 +180,41 @@ func (m *Manager) LoginByToken(loginID string, tokenValue string, device ...stri
return m.storage.Set(accountKey, tokenValue, expiration)
}
// Logout 登出
// Logout Performs user logout | 登出
func (m *Manager) Logout(loginID string, device ...string) error {
deviceType := "default"
if len(device) > 0 {
deviceType = device[0]
}
deviceType := getDevice(device)
accountKey := m.getAccountKey(loginID, deviceType)
tokenValue, err := m.storage.Get(accountKey)
if err != nil || tokenValue == nil {
return nil // 已经登出
return nil // Already logged out | 已经登出
}
// 删除Token
tokenKey := m.getTokenKey(tokenValue.(string))
// Delete token | 删除Token
tokenStr, ok := assertString(tokenValue)
if !ok {
return nil
}
tokenKey := m.getTokenKey(tokenStr)
m.storage.Delete(tokenKey)
// 删除账号映射
// Delete account mapping | 删除账号映射
m.storage.Delete(accountKey)
return nil
}
// LogoutByToken 根据Token登出
// LogoutByToken Logout by token | 根据Token登出
func (m *Manager) LogoutByToken(tokenValue string) error {
if tokenValue == "" {
return nil
}
tokenKey := m.getTokenKey(tokenValue)
m.storage.Delete(tokenKey)
return nil
return m.storage.Delete(tokenKey)
}
// kickout 踢人下线
// kickout Kick user offline (private) | 踢人下线(私有)
func (m *Manager) kickout(loginID string, device string) error {
accountKey := m.getAccountKey(loginID, device)
tokenValue, err := m.storage.Get(accountKey)
@@ -176,34 +222,35 @@ func (m *Manager) kickout(loginID string, device string) error {
return nil
}
tokenKey := m.getTokenKey(tokenValue.(string))
tokenStr, ok := assertString(tokenValue)
if !ok {
return nil
}
tokenKey := m.getTokenKey(tokenStr)
return m.storage.Delete(tokenKey)
}
// Kickout 踢人下线(公开方法)
// Kickout Kick user offline (public method) | 踢人下线(公开方法)
func (m *Manager) Kickout(loginID string, device ...string) error {
deviceType := "default"
if len(device) > 0 {
deviceType = device[0]
}
deviceType := getDevice(device)
return m.kickout(loginID, deviceType)
}
// ============ Token验证 ============
// ============ Token Validation | Token验证 ============
// IsLogin 检查是否登录
// IsLogin Checks if user is logged in | 检查是否登录
func (m *Manager) IsLogin(tokenValue string) bool {
if tokenValue == "" {
return false
}
tokenKey := m.getTokenKey(tokenValue)
exists := m.storage.Exists(tokenKey)
if !exists {
if !m.storage.Exists(tokenKey) {
return false
}
// 更新活跃时间并检查活跃超时
// Check and update active timeout | 更新活跃时间并检查活跃超时
if m.config.ActiveTimeout > 0 {
info, _ := m.getTokenInfo(tokenValue)
if info != nil {
@@ -215,38 +262,41 @@ func (m *Manager) IsLogin(tokenValue string) bool {
}
}
// 异步自动续期(提高性能)
// Async auto-renew for better performance | 异步自动续期(提高性能)
if m.config.AutoRenew && m.config.Timeout > 0 {
go func() {
expiration := time.Duration(m.config.Timeout) * time.Second
// 延长Token存储的过期时间
m.storage.Expire(tokenKey, expiration)
// 更新活跃时间
info, _ := m.getTokenInfo(tokenValue)
if info != nil {
info.ActiveTime = time.Now().Unix()
m.saveTokenInfo(tokenValue, info, expiration)
}
}()
go m.renewToken(tokenValue, tokenKey)
}
return true
}
// CheckLogin 检查登录(未登录抛出错误)
// renewToken Renews token expiration asynchronously | 异步续期Token
func (m *Manager) renewToken(tokenValue, tokenKey string) {
expiration := m.getExpiration()
// Extend token storage expiration | 延长Token存储的过期时间
m.storage.Expire(tokenKey, expiration)
// Update active time | 更新活跃时间
info, _ := m.getTokenInfo(tokenValue)
if info != nil {
info.ActiveTime = time.Now().Unix()
m.saveTokenInfo(tokenValue, info, expiration)
}
}
// CheckLogin Checks login status (throws error if not logged in) | 检查登录(未登录抛出错误)
func (m *Manager) CheckLogin(tokenValue string) error {
if !m.IsLogin(tokenValue) {
return fmt.Errorf("not login")
return ErrNotLogin
}
return nil
}
// GetLoginID 根据Token获取登录ID
// GetLoginID Gets login ID from token | 根据Token获取登录ID
func (m *Manager) GetLoginID(tokenValue string) (string, error) {
if !m.IsLogin(tokenValue) {
return "", fmt.Errorf("not login")
return "", ErrNotLogin
}
info, err := m.getTokenInfo(tokenValue)
@@ -257,7 +307,7 @@ func (m *Manager) GetLoginID(tokenValue string) (string, error) {
return info.LoginID, nil
}
// GetLoginIDNotCheck 获取登录ID(不检查Token是否有效)
// GetLoginIDNotCheck Gets login ID without checking token validity | 获取登录ID(不检查Token是否有效)
func (m *Manager) GetLoginIDNotCheck(tokenValue string) (string, error) {
info, err := m.getTokenInfo(tokenValue)
if err != nil {
@@ -266,50 +316,52 @@ func (m *Manager) GetLoginIDNotCheck(tokenValue string) (string, error) {
return info.LoginID, nil
}
// GetTokenValue 根据登录ID获取Token
// GetTokenValue Gets token by login ID | 根据登录ID获取Token
func (m *Manager) GetTokenValue(loginID string, device ...string) (string, error) {
deviceType := "default"
if len(device) > 0 {
deviceType = device[0]
}
deviceType := getDevice(device)
accountKey := m.getAccountKey(loginID, deviceType)
tokenValue, err := m.storage.Get(accountKey)
if err != nil || tokenValue == nil {
return "", fmt.Errorf("token not found for login id: %s", loginID)
}
return tokenValue.(string), nil
tokenStr, ok := assertString(tokenValue)
if !ok {
return "", fmt.Errorf("invalid token value type")
}
return tokenStr, nil
}
// GetTokenInfo 获取Token信息
// GetTokenInfo Gets token information | 获取Token信息
func (m *Manager) GetTokenInfo(tokenValue string) (*TokenInfo, error) {
return m.getTokenInfo(tokenValue)
}
// ============ 账号封禁 ============
// ============ Account Disable | 账号封禁 ============
// Disable 封禁账号
// Disable Disables an account | 封禁账号
func (m *Manager) Disable(loginID string, duration time.Duration) error {
key := m.prefix + "disable:" + loginID
return m.storage.Set(key, "1", duration)
key := m.getDisableKey(loginID)
return m.storage.Set(key, DisableValue, duration)
}
// Untie 解封账号
// Untie Re-enables a disabled account | 解封账号
func (m *Manager) Untie(loginID string) error {
key := m.prefix + "disable:" + loginID
key := m.getDisableKey(loginID)
return m.storage.Delete(key)
}
// IsDisable 检查账号是否被封禁
// IsDisable Checks if account is disabled | 检查账号是否被封禁
func (m *Manager) IsDisable(loginID string) bool {
key := m.prefix + "disable:" + loginID
key := m.getDisableKey(loginID)
return m.storage.Exists(key)
}
// GetDisableTime 获取账号剩余封禁时间(秒)
// GetDisableTime Gets remaining disable time in seconds | 获取账号剩余封禁时间(秒)
func (m *Manager) GetDisableTime(loginID string) (int64, error) {
key := m.prefix + "disable:" + loginID
key := m.getDisableKey(loginID)
ttl, err := m.storage.TTL(key)
if err != nil {
return -2, err
@@ -317,9 +369,14 @@ func (m *Manager) GetDisableTime(loginID string) (int64, error) {
return int64(ttl.Seconds()), nil
}
// ============ Session管理 ============
// getDisableKey Gets disable storage key | 获取禁用存储键
func (m *Manager) getDisableKey(loginID string) string {
return m.prefix + DisableKeyPrefix + loginID
}
// GetSession 获取Session
// ============ Session Management | Session管理 ============
// GetSession Gets session by login ID | 获取Session
func (m *Manager) GetSession(loginID string) (*session.Session, error) {
sess, err := session.Load(loginID, m.storage, m.prefix)
if err != nil {
@@ -328,7 +385,7 @@ func (m *Manager) GetSession(loginID string) (*session.Session, error) {
return sess, nil
}
// GetSessionByToken 根据Token获取Session
// GetSessionByToken Gets session by token | 根据Token获取Session
func (m *Manager) GetSessionByToken(tokenValue string) (*session.Session, error) {
loginID, err := m.GetLoginID(tokenValue)
if err != nil {
@@ -337,7 +394,7 @@ func (m *Manager) GetSessionByToken(tokenValue string) (*session.Session, error)
return m.GetSession(loginID)
}
// DeleteSession 删除Session
// DeleteSession Deletes session | 删除Session
func (m *Manager) DeleteSession(loginID string) error {
sess, err := m.GetSession(loginID)
if err != nil {
@@ -346,25 +403,25 @@ func (m *Manager) DeleteSession(loginID string) error {
return sess.Destroy()
}
// ============ 权限验证 ============
// ============ Permission Validation | 权限验证 ============
// SetPermissions 设置权限
// SetPermissions Sets permissions for user | 设置权限
func (m *Manager) SetPermissions(loginID string, permissions []string) error {
sess, err := m.GetSession(loginID)
if err != nil {
return err
}
return sess.Set("permissions", permissions)
return sess.Set(SessionKeyPermissions, permissions)
}
// GetPermissions 获取权限列表
// GetPermissions Gets permission list | 获取权限列表
func (m *Manager) GetPermissions(loginID string) ([]string, error) {
sess, err := m.GetSession(loginID)
if err != nil {
return nil, err
}
perms, exists := sess.Get("permissions")
perms, exists := sess.Get(SessionKeyPermissions)
if !exists {
return []string{}, nil
}
@@ -408,27 +465,29 @@ func (m *Manager) HasPermissionsOr(loginID string, permissions []string) bool {
return false
}
// matchPermission 权限匹配(支持通配符)
// matchPermission Matches permission with wildcards support | 权限匹配(支持通配符)
func (m *Manager) matchPermission(pattern, permission string) bool {
if pattern == "*" || pattern == permission {
// Exact match or wildcard | 精确匹配或通配符
if pattern == PermissionWildcard || pattern == permission {
return true
}
// 支持通配符,例如 user:* 匹配 user:add, user:delete等
if strings.HasSuffix(pattern, ":*") {
prefix := strings.TrimSuffix(pattern, "*")
// Pattern like "user:*" matches "user:add", "user:delete", etc. | 支持通配符,例如 user:* 匹配 user:add, user:delete等
wildcardSuffix := PermissionSeparator + PermissionWildcard
if strings.HasSuffix(pattern, wildcardSuffix) {
prefix := strings.TrimSuffix(pattern, PermissionWildcard)
return strings.HasPrefix(permission, prefix)
}
// 支持 user:*:view 这样的模式
if strings.Contains(pattern, "*") {
parts := strings.Split(pattern, ":")
permParts := strings.Split(permission, ":")
// Pattern like "user:*:view" | 支持 user:*:view 这样的模式
if strings.Contains(pattern, PermissionWildcard) {
parts := strings.Split(pattern, PermissionSeparator)
permParts := strings.Split(permission, PermissionSeparator)
if len(parts) != len(permParts) {
return false
}
for i, part := range parts {
if part != "*" && part != permParts[i] {
if part != PermissionWildcard && part != permParts[i] {
return false
}
}
@@ -438,25 +497,25 @@ func (m *Manager) matchPermission(pattern, permission string) bool {
return false
}
// ============ 角色验证 ============
// ============ Role Validation | 角色验证 ============
// SetRoles 设置角色
// SetRoles Sets roles for user | 设置角色
func (m *Manager) SetRoles(loginID string, roles []string) error {
sess, err := m.GetSession(loginID)
if err != nil {
return err
}
return sess.Set("roles", roles)
return sess.Set(SessionKeyRoles, roles)
}
// GetRoles 获取角色列表
// GetRoles Gets role list | 获取角色列表
func (m *Manager) GetRoles(loginID string) ([]string, error) {
sess, err := m.GetSession(loginID)
if err != nil {
return nil, err
}
roles, exists := sess.Get("roles")
roles, exists := sess.Get(SessionKeyRoles)
if !exists {
return []string{}, nil
}
@@ -499,9 +558,9 @@ func (m *Manager) HasRolesOr(loginID string, roles []string) bool {
return false
}
// ============ Token标签 ============
// ============ Token Tags | Token标签 ============
// SetTokenTag 设置Token标签
// SetTokenTag Sets token tag | 设置Token标签
func (m *Manager) SetTokenTag(tokenValue, tag string) error {
info, err := m.getTokenInfo(tokenValue)
if err != nil {
@@ -509,16 +568,12 @@ func (m *Manager) SetTokenTag(tokenValue, tag string) error {
}
info.Tag = tag
var expiration time.Duration
if m.config.Timeout > 0 {
expiration = time.Duration(m.config.Timeout) * time.Second
}
expiration := m.getExpiration()
return m.saveTokenInfo(tokenValue, info, expiration)
}
// GetTokenTag 获取Token标签
// GetTokenTag Gets token tag | 获取Token标签
func (m *Manager) GetTokenTag(tokenValue string) (string, error) {
info, err := m.getTokenInfo(tokenValue)
if err != nil {
@@ -527,28 +582,30 @@ func (m *Manager) GetTokenTag(tokenValue string) (string, error) {
return info.Tag, nil
}
// ============ 会话查询 ============
// ============ Session Query | 会话查询 ============
// GetTokenValueListByLoginID 获取指定账号的所有Token
// GetTokenValueListByLoginID Gets all tokens for specified account | 获取指定账号的所有Token
func (m *Manager) GetTokenValueListByLoginID(loginID string) ([]string, error) {
pattern := m.prefix + "account:" + loginID + ":*"
pattern := m.prefix + AccountKeyPrefix + loginID + ":*"
keys, err := m.storage.Keys(pattern)
if err != nil {
return nil, err
}
tokens := make([]string, 0)
tokens := make([]string, 0, len(keys))
for _, key := range keys {
value, err := m.storage.Get(key)
if err == nil && value != nil {
tokens = append(tokens, value.(string))
if tokenStr, ok := assertString(value); ok {
tokens = append(tokens, tokenStr)
}
}
}
return tokens, nil
}
// GetSessionCountByLoginID 获取指定账号的Session数量
// GetSessionCountByLoginID Gets session count for specified account | 获取指定账号的Session数量
func (m *Manager) GetSessionCountByLoginID(loginID string) (int, error) {
tokens, err := m.GetTokenValueListByLoginID(loginID)
if err != nil {
@@ -557,19 +614,19 @@ func (m *Manager) GetSessionCountByLoginID(loginID string) (int, error) {
return len(tokens), nil
}
// ============ 辅助方法 ============
// ============ Internal Helper Methods | 内部辅助方法 ============
// getTokenKey 获取Token存储键
// getTokenKey Gets token storage key | 获取Token存储键
func (m *Manager) getTokenKey(tokenValue string) string {
return m.prefix + "token:" + tokenValue
return m.prefix + TokenKeyPrefix + tokenValue
}
// getAccountKey 获取账号存储键
// getAccountKey Gets account storage key | 获取账号存储键
func (m *Manager) getAccountKey(loginID, device string) string {
return m.prefix + "account:" + loginID + ":" + device
return m.prefix + AccountKeyPrefix + loginID + PermissionSeparator + device
}
// saveTokenInfo 保存Token信息
// saveTokenInfo Saves token information | 保存Token信息
func (m *Manager) saveTokenInfo(tokenValue string, info *TokenInfo, expiration time.Duration) error {
data, err := json.Marshal(info)
if err != nil {
@@ -580,28 +637,33 @@ func (m *Manager) saveTokenInfo(tokenValue string, info *TokenInfo, expiration t
return m.storage.Set(tokenKey, string(data), expiration)
}
// getTokenInfo 获取Token信息
// getTokenInfo Gets token information | 获取Token信息
func (m *Manager) getTokenInfo(tokenValue string) (*TokenInfo, error) {
tokenKey := m.getTokenKey(tokenValue)
data, err := m.storage.Get(tokenKey)
if err != nil || data == nil {
return nil, fmt.Errorf("token not found")
return nil, ErrTokenNotFound
}
dataStr, ok := assertString(data)
if !ok {
return nil, ErrInvalidTokenData
}
var info TokenInfo
if err := json.Unmarshal([]byte(data.(string)), &info); err != nil {
return nil, fmt.Errorf("invalid token data: %w", err)
if err := json.Unmarshal([]byte(dataStr), &info); err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidTokenData, err)
}
return &info, nil
}
// toStringSlice 将interface{}转换为[]string
func (m *Manager) toStringSlice(v interface{}) []string {
// toStringSlice Converts any to []string | 将any转换为[]string
func (m *Manager) toStringSlice(v any) []string {
switch val := v.(type) {
case []string:
return val
case []interface{}:
case []any:
result := make([]string, 0, len(val))
for _, item := range val {
if str, ok := item.(string); ok {
@@ -614,36 +676,46 @@ func (m *Manager) toStringSlice(v interface{}) []string {
}
}
// GetConfig 获取配置
// ============ Public Getters | 公共获取器 ============
// GetConfig Gets configuration | 获取配置
func (m *Manager) GetConfig() *config.Config {
return m.config
}
// GetStorage 获取存储
// GetStorage Gets storage | 获取存储
func (m *Manager) GetStorage() adapter.Storage {
return m.storage
}
// ============ Security Features | 安全特性 ============
// GenerateNonce Generates a one-time nonce | 生成一次性随机数
func (m *Manager) GenerateNonce() (string, error) {
return m.nonceManager.Generate()
}
// VerifyNonce Verifies a nonce | 验证随机数
func (m *Manager) VerifyNonce(nonce string) bool {
return m.nonceManager.Verify(nonce)
}
// LoginWithRefreshToken Logs in with refresh token | 使用刷新令牌登录
func (m *Manager) LoginWithRefreshToken(loginID, device string) (*security.RefreshTokenInfo, error) {
return m.refreshManager.GenerateTokenPair(loginID, device)
}
// RefreshAccessToken Refreshes access token | 刷新访问令牌
func (m *Manager) RefreshAccessToken(refreshToken string) (*security.RefreshTokenInfo, error) {
return m.refreshManager.RefreshAccessToken(refreshToken)
}
// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌
func (m *Manager) RevokeRefreshToken(refreshToken string) error {
return m.refreshManager.RevokeRefreshToken(refreshToken)
}
// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例
func (m *Manager) GetOAuth2Server() *oauth2.OAuth2Server {
return m.oauth2Server
}
+156 -58
View File
@@ -4,6 +4,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"sync"
"time"
"github.com/click33/sa-token-go/core/adapter"
@@ -20,11 +21,42 @@ import (
// 5. RefreshAccessToken() - Use refresh token to get new token | 用刷新令牌获取新令牌
//
// Usage | 用法:
// server := core.NewOAuth2Server(storage)
// server.RegisterClient(&core.OAuth2Client{...})
// server := oauth2.NewOAuth2Server(storage)
// server.RegisterClient(&oauth2.Client{...})
// authCode, _ := server.GenerateAuthorizationCode(...)
// token, _ := server.ExchangeCodeForToken(...)
// Constants for OAuth2 | OAuth2常量
const (
DefaultCodeExpiration = 10 * time.Minute // Authorization code expiration | 授权码过期时间
DefaultTokenExpiration = 2 * time.Hour // Access token expiration | 访问令牌过期时间
DefaultRefreshTTL = 30 * 24 * time.Hour // Refresh token expiration | 刷新令牌过期时间
CodeLength = 32 // Authorization code byte length | 授权码字节长度
AccessTokenLength = 32 // Access token byte length | 访问令牌字节长度
RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度
CodeKeyPrefix = "satoken:oauth2:code:" // Code storage key prefix | 授权码存储键前缀
TokenKeyPrefix = "satoken:oauth2:token:" // Token storage key prefix | 令牌存储键前缀
RefreshKeyPrefix = "satoken:oauth2:refresh:" // Refresh storage key prefix | 刷新令牌存储键前缀
TokenTypeBearer = "Bearer" // Token type | 令牌类型
)
// Error variables | 错误变量
var (
ErrClientNotFound = fmt.Errorf("client not found")
ErrInvalidRedirectURI = fmt.Errorf("invalid redirect_uri")
ErrInvalidClientCredentials = fmt.Errorf("invalid client credentials")
ErrInvalidAuthCode = fmt.Errorf("invalid authorization code")
ErrAuthCodeUsed = fmt.Errorf("authorization code already used")
ErrAuthCodeExpired = fmt.Errorf("authorization code expired")
ErrClientMismatch = fmt.Errorf("client mismatch")
ErrRedirectURIMismatch = fmt.Errorf("redirect_uri mismatch")
ErrInvalidAccessToken = fmt.Errorf("invalid access token")
ErrInvalidTokenData = fmt.Errorf("invalid token data")
)
// GrantType OAuth2 grant type | OAuth2授权类型
type GrantType string
@@ -71,55 +103,74 @@ type AccessToken struct {
type OAuth2Server struct {
storage adapter.Storage
clients map[string]*Client
clientsMu sync.RWMutex // Clients map lock | 客户端映射锁
codeExpiration time.Duration // Authorization code expiration (10min) | 授权码过期时间(10分钟)
tokenExpiration time.Duration // Access token expiration (2h) | 访问令牌过期时间(2小时)
}
// NewOAuth2Server creates a new OAuth2 server | 创建新的OAuth2服务器
// NewOAuth2Server Creates a new OAuth2 server | 创建新的OAuth2服务器
func NewOAuth2Server(storage adapter.Storage) *OAuth2Server {
return &OAuth2Server{
storage: storage,
clients: make(map[string]*Client),
codeExpiration: 10 * time.Minute, // Authorization code expires in 10 minutes | 授权码10分钟过期
tokenExpiration: 2 * time.Hour, // Access token expires in 2 hours | 访问令牌2小时过期
codeExpiration: DefaultCodeExpiration,
tokenExpiration: DefaultTokenExpiration,
}
}
// RegisterClient registers an OAuth2 client | 注册OAuth2客户端
func (s *OAuth2Server) RegisterClient(client *Client) {
// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端
func (s *OAuth2Server) RegisterClient(client *Client) error {
if client == nil || client.ClientID == "" {
return fmt.Errorf("invalid client: clientID is required")
}
s.clientsMu.Lock()
defer s.clientsMu.Unlock()
s.clients[client.ClientID] = client
return nil
}
// GetClient gets client by ID | 根据ID获取客户端
// UnregisterClient Unregisters an OAuth2 client | 注销OAuth2客户端
func (s *OAuth2Server) UnregisterClient(clientID string) {
s.clientsMu.Lock()
defer s.clientsMu.Unlock()
delete(s.clients, clientID)
}
// GetClient Gets client by ID | 根据ID获取客户端
func (s *OAuth2Server) GetClient(clientID string) (*Client, error) {
s.clientsMu.RLock()
defer s.clientsMu.RUnlock()
client, exists := s.clients[clientID]
if !exists {
return nil, fmt.Errorf("client not found")
return nil, ErrClientNotFound
}
return client, nil
}
// GenerateAuthorizationCode generates authorization code | 生成授权码
// GenerateAuthorizationCode Generates authorization code | 生成授权码
func (s *OAuth2Server) GenerateAuthorizationCode(clientID, redirectURI, userID string, scopes []string) (*AuthorizationCode, error) {
if userID == "" {
return nil, fmt.Errorf("userID cannot be empty")
}
client, err := s.GetClient(clientID)
if err != nil {
return nil, err
}
validRedirect := false
for _, uri := range client.RedirectURIs {
if uri == redirectURI {
validRedirect = true
break
}
}
if !validRedirect {
return nil, fmt.Errorf("invalid redirect_uri")
// Validate redirect URI | 验证回调URI
if !s.isValidRedirectURI(client, redirectURI) {
return nil, ErrInvalidRedirectURI
}
codeBytes := make([]byte, 32)
// Generate code | 生成授权码
codeBytes := make([]byte, CodeLength)
if _, err := rand.Read(codeBytes); err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate authorization code: %w", err)
}
code := hex.EncodeToString(codeBytes)
@@ -134,29 +185,41 @@ func (s *OAuth2Server) GenerateAuthorizationCode(clientID, redirectURI, userID s
Used: false,
}
key := fmt.Sprintf("satoken:oauth2:code:%s", code)
key := s.getCodeKey(code)
if err := s.storage.Set(key, authCode, s.codeExpiration); err != nil {
return nil, err
return nil, fmt.Errorf("failed to store authorization code: %w", err)
}
return authCode, nil
}
// ExchangeCodeForToken exchanges authorization code for access token | 用授权码换取访问令牌
// isValidRedirectURI Checks if redirect URI is valid for client | 检查回调URI是否有效
func (s *OAuth2Server) isValidRedirectURI(client *Client, redirectURI string) bool {
for _, uri := range client.RedirectURIs {
if uri == redirectURI {
return true
}
}
return false
}
// ExchangeCodeForToken Exchanges authorization code for access token | 用授权码换取访问令牌
func (s *OAuth2Server) ExchangeCodeForToken(code, clientID, clientSecret, redirectURI string) (*AccessToken, error) {
// Verify client credentials | 验证客户端凭证
client, err := s.GetClient(clientID)
if err != nil {
return nil, err
}
if client.ClientSecret != clientSecret {
return nil, fmt.Errorf("invalid client credentials")
return nil, ErrInvalidClientCredentials
}
key := fmt.Sprintf("satoken:oauth2:code:%s", code)
// Get authorization code | 获取授权码
key := s.getCodeKey(code)
data, err := s.storage.Get(key)
if err != nil {
return nil, fmt.Errorf("invalid authorization code")
if err != nil || data == nil {
return nil, ErrInvalidAuthCode
}
authCode, ok := data.(*AuthorizationCode)
@@ -164,45 +227,50 @@ func (s *OAuth2Server) ExchangeCodeForToken(code, clientID, clientSecret, redire
return nil, fmt.Errorf("invalid code data")
}
// Validate authorization code | 验证授权码
if authCode.Used {
return nil, fmt.Errorf("authorization code already used")
return nil, ErrAuthCodeUsed
}
if authCode.ClientID != clientID {
return nil, fmt.Errorf("client mismatch")
return nil, ErrClientMismatch
}
if authCode.RedirectURI != redirectURI {
return nil, fmt.Errorf("redirect_uri mismatch")
return nil, ErrRedirectURIMismatch
}
if time.Now().Unix() > authCode.CreateTime+authCode.ExpiresIn {
s.storage.Delete(key)
return nil, fmt.Errorf("authorization code expired")
return nil, ErrAuthCodeExpired
}
// Mark code as used | 标记为已使用
authCode.Used = true
s.storage.Set(key, authCode, time.Minute)
return s.generateAccessToken(authCode.UserID, authCode.ClientID, authCode.Scopes)
}
// generateAccessToken Generates access token and refresh token | 生成访问令牌和刷新令牌
func (s *OAuth2Server) generateAccessToken(userID, clientID string, scopes []string) (*AccessToken, error) {
tokenBytes := make([]byte, 32)
// Generate access token | 生成访问令牌
tokenBytes := make([]byte, AccessTokenLength)
if _, err := rand.Read(tokenBytes); err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate access token: %w", err)
}
accessToken := hex.EncodeToString(tokenBytes)
refreshBytes := make([]byte, 32)
// Generate refresh token | 生成刷新令牌
refreshBytes := make([]byte, RefreshTokenLength)
if _, err := rand.Read(refreshBytes); err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
refreshToken := hex.EncodeToString(refreshBytes)
token := &AccessToken{
Token: accessToken,
TokenType: "Bearer",
TokenType: TokenTypeBearer,
ExpiresIn: int64(s.tokenExpiration.Seconds()),
RefreshToken: refreshToken,
Scopes: scopes,
@@ -210,50 +278,58 @@ func (s *OAuth2Server) generateAccessToken(userID, clientID string, scopes []str
ClientID: clientID,
}
tokenKey := fmt.Sprintf("satoken:oauth2:token:%s", accessToken)
refreshKey := fmt.Sprintf("satoken:oauth2:refresh:%s", refreshToken)
tokenKey := s.getTokenKey(accessToken)
refreshKey := s.getRefreshKey(refreshToken)
// Store access token | 存储访问令牌
if err := s.storage.Set(tokenKey, token, s.tokenExpiration); err != nil {
return nil, err
return nil, fmt.Errorf("failed to store access token: %w", err)
}
if err := s.storage.Set(refreshKey, token, 30*24*time.Hour); err != nil {
return nil, err
// Store refresh token | 存储刷新令牌
if err := s.storage.Set(refreshKey, token, DefaultRefreshTTL); err != nil {
return nil, fmt.Errorf("failed to store refresh token: %w", err)
}
return token, nil
}
// ValidateAccessToken validates access token | 验证访问令牌
// ValidateAccessToken Validates access token | 验证访问令牌
func (s *OAuth2Server) ValidateAccessToken(tokenString string) (*AccessToken, error) {
key := fmt.Sprintf("satoken:oauth2:token:%s", tokenString)
if tokenString == "" {
return nil, ErrInvalidAccessToken
}
key := s.getTokenKey(tokenString)
data, err := s.storage.Get(key)
if err != nil {
return nil, fmt.Errorf("invalid access token")
if err != nil || data == nil {
return nil, ErrInvalidAccessToken
}
token, ok := data.(*AccessToken)
if !ok {
return nil, fmt.Errorf("invalid token data")
return nil, ErrInvalidTokenData
}
return token, nil
}
// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌
// RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌
func (s *OAuth2Server) RefreshAccessToken(refreshToken, clientID, clientSecret string) (*AccessToken, error) {
// Verify client credentials | 验证客户端凭证
client, err := s.GetClient(clientID)
if err != nil {
return nil, err
}
if client.ClientSecret != clientSecret {
return nil, fmt.Errorf("invalid client credentials")
return nil, ErrInvalidClientCredentials
}
key := fmt.Sprintf("satoken:oauth2:refresh:%s", refreshToken)
// Get refresh token | 获取刷新令牌
key := s.getRefreshKey(refreshToken)
data, err := s.storage.Get(key)
if err != nil {
if err != nil || data == nil {
return nil, fmt.Errorf("invalid refresh token")
}
@@ -263,28 +339,50 @@ func (s *OAuth2Server) RefreshAccessToken(refreshToken, clientID, clientSecret s
}
if oldToken.ClientID != clientID {
return nil, fmt.Errorf("client mismatch")
return nil, ErrClientMismatch
}
oldTokenKey := fmt.Sprintf("satoken:oauth2:token:%s", oldToken.Token)
// Delete old access token | 删除旧的访问令牌
oldTokenKey := s.getTokenKey(oldToken.Token)
s.storage.Delete(oldTokenKey)
return s.generateAccessToken(oldToken.UserID, oldToken.ClientID, oldToken.Scopes)
}
// RevokeToken revokes access token and its refresh token | 撤销访问令牌及其刷新令牌
// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌
func (s *OAuth2Server) RevokeToken(tokenString string) error {
key := fmt.Sprintf("satoken:oauth2:token:%s", tokenString)
if tokenString == "" {
return nil
}
key := s.getTokenKey(tokenString)
data, err := s.storage.Get(key)
if err != nil {
return err
}
token, ok := data.(*AccessToken)
if ok && token.RefreshToken != "" {
refreshKey := fmt.Sprintf("satoken:oauth2:refresh:%s", token.RefreshToken)
// Revoke refresh token if exists | 如果存在则撤销刷新令牌
if token, ok := data.(*AccessToken); ok && token.RefreshToken != "" {
refreshKey := s.getRefreshKey(token.RefreshToken)
s.storage.Delete(refreshKey)
}
return s.storage.Delete(key)
}
// ============ Helper Methods | 辅助方法 ============
// getCodeKey Gets storage key for authorization code | 获取授权码的存储键
func (s *OAuth2Server) getCodeKey(code string) string {
return CodeKeyPrefix + code
}
// getTokenKey Gets storage key for access token | 获取访问令牌的存储键
func (s *OAuth2Server) getTokenKey(token string) string {
return TokenKeyPrefix + token
}
// getRefreshKey Gets storage key for refresh token | 获取刷新令牌的存储键
func (s *OAuth2Server) getRefreshKey(refreshToken string) string {
return RefreshKeyPrefix + refreshToken
}
+39 -14
View File
@@ -19,6 +19,7 @@ import (
// Version Sa-Token-Go version | Sa-Token-Go版本
const Version = "0.1.0"
// ============ Exported Types | 导出的类型 ============
// Export main types and functions for external use | 导出主要类型和函数,方便外部使用
// Configuration related types | 配置相关类型
@@ -96,59 +97,81 @@ const (
GrantTypePassword = oauth2.GrantTypePassword
)
// Utility functions | 工具函数
// ============ Utility Functions | 工具函数 ============
var (
RandomString = utils.RandomString
IsEmpty = utils.IsEmpty
IsNotEmpty = utils.IsNotEmpty
DefaultString = utils.DefaultString
// String utilities | 字符串工具
RandomString = utils.RandomString
RandomNumericString = utils.RandomNumericString
RandomAlphanumeric = utils.RandomAlphanumeric
IsEmpty = utils.IsEmpty
IsNotEmpty = utils.IsNotEmpty
DefaultString = utils.DefaultString
// Slice utilities | 切片工具
ContainsString = utils.ContainsString
RemoveString = utils.RemoveString
UniqueStrings = utils.UniqueStrings
MergeStrings = utils.MergeStrings
MatchPattern = utils.MatchPattern
FilterStrings = utils.FilterStrings
MapStrings = utils.MapStrings
// Pattern matching | 模式匹配
MatchPattern = utils.MatchPattern
// Duration utilities | 时长工具
FormatDuration = utils.FormatDuration
ParseDuration = utils.ParseDuration
// Hash & Encoding | 哈希和编码
SHA256Hash = utils.SHA256Hash
Base64Encode = utils.Base64Encode
Base64Decode = utils.Base64Decode
)
// DefaultConfig returns default configuration | 返回默认配置
// ============ Factory Functions | 工厂函数 ============
// DefaultConfig Returns default configuration | 返回默认配置
func DefaultConfig() *Config {
return config.DefaultConfig()
}
// NewManager creates a new authentication manager | 创建新的认证管理器
// NewManager Creates a new authentication manager | 创建新的认证管理器
func NewManager(storage Storage, cfg *Config) *Manager {
return manager.NewManager(storage, cfg)
}
// NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文
// NewContext Creates a new Sa-Token context | 创建新的Sa-Token上下文
func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext {
return context.NewContext(ctx, mgr)
}
// NewSession creates a new session | 创建新的Session
// NewSession Creates a new session | 创建新的Session
func NewSession(id string, storage Storage, prefix string) *Session {
return session.NewSession(id, storage, prefix)
}
// LoadSession loads an existing session | 加载已存在的Session
// LoadSession Loads an existing session | 加载已存在的Session
func LoadSession(id string, storage Storage, prefix string) (*Session, error) {
return session.Load(id, storage, prefix)
}
// NewTokenGenerator creates a new token generator | 创建新的Token生成器
// NewTokenGenerator Creates a new token generator | 创建新的Token生成器
func NewTokenGenerator(cfg *Config) *TokenGenerator {
return token.NewGenerator(cfg)
}
// NewEventManager creates a new event manager | 创建新的事件管理器
// NewEventManager Creates a new event manager | 创建新的事件管理器
func NewEventManager() *EventManager {
return listener.NewManager()
}
// NewBuilder creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置)
// NewBuilder Creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置)
func NewBuilder() *Builder {
return builder.NewBuilder()
}
// NewNonceManager Creates a new nonce manager | 创建新的Nonce管理器
func NewNonceManager(storage Storage, ttl ...int64) *NonceManager {
var duration time.Duration
if len(ttl) > 0 && ttl[0] > 0 {
@@ -157,10 +180,12 @@ func NewNonceManager(storage Storage, ttl ...int64) *NonceManager {
return security.NewNonceManager(storage, duration)
}
// NewRefreshTokenManager Creates a new refresh token manager | 创建新的刷新令牌管理器
func NewRefreshTokenManager(storage Storage, cfg *Config) *RefreshTokenManager {
return security.NewRefreshTokenManager(storage, cfg)
}
// NewOAuth2Server Creates a new OAuth2 server | 创建新的OAuth2服务器
func NewOAuth2Server(storage Storage) *OAuth2Server {
return oauth2.NewOAuth2Server(storage)
}
+40 -13
View File
@@ -23,6 +23,18 @@ import (
// valid := manager.VerifyNonce(nonce) // true
// valid = manager.VerifyNonce(nonce) // false (replay prevented)
// Constants for nonce | Nonce常量
const (
DefaultNonceTTL = 5 * time.Minute // Default nonce expiration | 默认nonce过期时间
NonceLength = 32 // Nonce byte length | Nonce字节长度
NonceKeyPrefix = "satoken:nonce:" // Storage key prefix | 存储键前缀
)
// Error variables | 错误变量
var (
ErrInvalidNonce = fmt.Errorf("invalid or expired nonce")
)
// NonceManager Nonce manager for anti-replay attacks | Nonce管理器,用于防重放攻击
type NonceManager struct {
storage adapter.Storage
@@ -30,11 +42,11 @@ type NonceManager struct {
mu sync.RWMutex
}
// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器
// NewNonceManager Creates a new nonce manager | 创建新的Nonce管理器
// ttl: time to live, default 5 minutes | 过期时间,默认5分钟
func NewNonceManager(storage adapter.Storage, ttl time.Duration) *NonceManager {
if ttl == 0 {
ttl = 5 * time.Minute
ttl = DefaultNonceTTL
}
return &NonceManager{
storage: storage,
@@ -42,31 +54,31 @@ func NewNonceManager(storage adapter.Storage, ttl time.Duration) *NonceManager {
}
}
// Generate generates a new nonce and stores it | 生成新的nonce并存储
// Generate Generates a new nonce and stores it | 生成新的nonce并存储
// Returns 64-char hex string | 返回64字符的十六进制字符串
func (nm *NonceManager) Generate() (string, error) {
bytes := make([]byte, 32)
bytes := make([]byte, NonceLength)
if _, err := rand.Read(bytes); err != nil {
return "", err
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
nonce := hex.EncodeToString(bytes)
key := fmt.Sprintf("satoken:nonce:%s", nonce)
key := nm.getNonceKey(nonce)
if err := nm.storage.Set(key, time.Now().Unix(), nm.ttl); err != nil {
return "", err
return "", fmt.Errorf("failed to store nonce: %w", err)
}
return nonce, nil
}
// Verify verifies nonce and consumes it (one-time use) | 验证nonce并消费它(一次性使用)
// Verify Verifies nonce and consumes it (one-time use) | 验证nonce并消费它(一次性使用)
// Returns false if nonce doesn't exist or already used | 如果nonce不存在或已使用则返回false
func (nm *NonceManager) Verify(nonce string) bool {
if nonce == "" {
return false
}
key := fmt.Sprintf("satoken:nonce:%s", nonce)
key := nm.getNonceKey(nonce)
nm.mu.Lock()
defer nm.mu.Unlock()
@@ -79,14 +91,29 @@ func (nm *NonceManager) Verify(nonce string) bool {
return true
}
// VerifyAndConsume verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误
// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误
func (nm *NonceManager) VerifyAndConsume(nonce string) error {
if !nm.Verify(nonce) {
return fmt.Errorf("invalid or expired nonce")
return ErrInvalidNonce
}
return nil
}
// Clean cleans expired nonces (handled by storage TTL) | 清理过期的nonce(由存储的TTL处理
func (nm *NonceManager) Clean() {
// IsValid Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费
func (nm *NonceManager) IsValid(nonce string) bool {
if nonce == "" {
return false
}
key := nm.getNonceKey(nonce)
nm.mu.RLock()
defer nm.mu.RUnlock()
return nm.storage.Exists(key)
}
// getNonceKey Gets storage key for nonce | 获取nonce的存储键
func (nm *NonceManager) getNonceKey(nonce string) string {
return NonceKeyPrefix + nonce
}
+77 -28
View File
@@ -25,6 +25,21 @@ import (
// // ... access token expires ...
// newInfo, _ := manager.RefreshAccessToken(tokenInfo.RefreshToken)
// Constants for refresh token | 刷新令牌常量
const (
DefaultRefreshTTL = 30 * 24 * time.Hour // 30 days | 30天
DefaultAccessTTL = 2 * time.Hour // 2 hours | 2小时
RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度
RefreshKeyPrefix = "satoken:refresh:" // Storage key prefix | 存储键前缀
)
// Error variables | 错误变量
var (
ErrInvalidRefreshToken = fmt.Errorf("invalid refresh token")
ErrRefreshTokenExpired = fmt.Errorf("refresh token expired")
ErrInvalidRefreshData = fmt.Errorf("invalid refresh token data")
)
// RefreshTokenInfo refresh token information | 刷新令牌信息
type RefreshTokenInfo struct {
RefreshToken string // Refresh token (long-lived) | 刷新令牌(长期有效)
@@ -35,7 +50,7 @@ type RefreshTokenInfo struct {
ExpireTime int64 // Expiration timestamp | 过期时间戳
}
// RefreshTokenManager refresh token manager | 刷新令牌管理器
// RefreshTokenManager Refresh token manager | 刷新令牌管理器
type RefreshTokenManager struct {
storage adapter.Storage
tokenGen *token.Generator
@@ -43,34 +58,39 @@ type RefreshTokenManager struct {
accessTTL time.Duration // Access token TTL (configurable) | 访问令牌有效期(可配置)
}
// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器
// NewRefreshTokenManager Creates a new refresh token manager | 创建新的刷新令牌管理器
// cfg: configuration, uses Timeout for access token TTL | 配置,使用Timeout作为访问令牌有效期
func NewRefreshTokenManager(storage adapter.Storage, cfg *config.Config) *RefreshTokenManager {
refreshTTL := 30 * 24 * time.Hour // 30 days | 30天
accessTTL := time.Duration(cfg.Timeout) * time.Second
if accessTTL == 0 {
accessTTL = 2 * time.Hour // Default 2 hours | 默认2小时
accessTTL = DefaultAccessTTL
}
return &RefreshTokenManager{
storage: storage,
tokenGen: token.NewGenerator(cfg),
refreshTTL: refreshTTL,
refreshTTL: DefaultRefreshTTL,
accessTTL: accessTTL,
}
}
// GenerateTokenPair generates access token and refresh token pair | 生成访问令牌和刷新令牌对
// GenerateTokenPair Generates access token and refresh token pair | 生成访问令牌和刷新令牌对
func (rtm *RefreshTokenManager) GenerateTokenPair(loginID, device string) (*RefreshTokenInfo, error) {
accessToken, err := rtm.tokenGen.Generate(loginID, device)
if err != nil {
return nil, err
if loginID == "" {
return nil, fmt.Errorf("loginID cannot be empty")
}
refreshTokenBytes := make([]byte, 32)
// Generate access token | 生成访问令牌
accessToken, err := rtm.tokenGen.Generate(loginID, device)
if err != nil {
return nil, fmt.Errorf("failed to generate access token: %w", err)
}
// Generate refresh token | 生成刷新令牌
refreshTokenBytes := make([]byte, RefreshTokenLength)
if _, err := rand.Read(refreshTokenBytes); err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
refreshToken := hex.EncodeToString(refreshTokenBytes)
@@ -84,66 +104,95 @@ func (rtm *RefreshTokenManager) GenerateTokenPair(loginID, device string) (*Refr
ExpireTime: now.Add(rtm.refreshTTL).Unix(),
}
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
key := rtm.getRefreshKey(refreshToken)
if err := rtm.storage.Set(key, info, rtm.refreshTTL); err != nil {
return nil, err
return nil, fmt.Errorf("failed to store refresh token: %w", err)
}
return info, nil
}
// RefreshAccessToken generates new access token using refresh token | 使用刷新令牌生成新的访问令牌
// RefreshAccessToken Generates new access token using refresh token | 使用刷新令牌生成新的访问令牌
func (rtm *RefreshTokenManager) RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) {
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
if refreshToken == "" {
return nil, ErrInvalidRefreshToken
}
key := rtm.getRefreshKey(refreshToken)
data, err := rtm.storage.Get(key)
if err != nil {
return nil, fmt.Errorf("invalid refresh token")
if err != nil || data == nil {
return nil, ErrInvalidRefreshToken
}
oldInfo, ok := data.(*RefreshTokenInfo)
if !ok {
return nil, fmt.Errorf("invalid refresh token data")
return nil, ErrInvalidRefreshData
}
// Check expiration | 检查是否过期
if time.Now().Unix() > oldInfo.ExpireTime {
rtm.storage.Delete(key)
return nil, fmt.Errorf("refresh token expired")
return nil, ErrRefreshTokenExpired
}
// Generate new access token | 生成新的访问令牌
newAccessToken, err := rtm.tokenGen.Generate(oldInfo.LoginID, oldInfo.Device)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate new access token: %w", err)
}
oldInfo.AccessToken = newAccessToken
// Update storage | 更新存储
if err := rtm.storage.Set(key, oldInfo, rtm.refreshTTL); err != nil {
return nil, err
return nil, fmt.Errorf("failed to update refresh token: %w", err)
}
return oldInfo, nil
}
// RevokeRefreshToken revokes a refresh token | 撤销刷新令牌
// RevokeRefreshToken Revokes a refresh token | 撤销刷新令牌
func (rtm *RefreshTokenManager) RevokeRefreshToken(refreshToken string) error {
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
if refreshToken == "" {
return nil
}
key := rtm.getRefreshKey(refreshToken)
return rtm.storage.Delete(key)
}
// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息
// GetRefreshTokenInfo Gets refresh token information | 获取刷新令牌信息
func (rtm *RefreshTokenManager) GetRefreshTokenInfo(refreshToken string) (*RefreshTokenInfo, error) {
key := fmt.Sprintf("satoken:refresh:%s", refreshToken)
if refreshToken == "" {
return nil, ErrInvalidRefreshToken
}
key := rtm.getRefreshKey(refreshToken)
data, err := rtm.storage.Get(key)
if err != nil {
return nil, err
if err != nil || data == nil {
return nil, ErrInvalidRefreshToken
}
info, ok := data.(*RefreshTokenInfo)
if !ok {
return nil, fmt.Errorf("invalid refresh token data")
return nil, ErrInvalidRefreshData
}
return info, nil
}
// IsValid Checks if refresh token is valid | 检查刷新令牌是否有效
func (rtm *RefreshTokenManager) IsValid(refreshToken string) bool {
info, err := rtm.GetRefreshTokenInfo(refreshToken)
if err != nil {
return false
}
return time.Now().Unix() <= info.ExpireTime
}
// getRefreshKey Gets storage key for refresh token | 获取刷新令牌的存储键
func (rtm *RefreshTokenManager) getRefreshKey(refreshToken string) string {
return RefreshKeyPrefix + refreshToken
}
+69 -26
View File
@@ -9,29 +9,46 @@ import (
"github.com/click33/sa-token-go/core/adapter"
)
// Session session object for storing user data | 会话对象,用于存储用户数据
// Constants for session keys | Session键常量
const (
SessionKeyPrefix = "session:" // Storage key prefix | 存储键前缀
)
// Error variables | 错误变量
var (
ErrSessionNotFound = fmt.Errorf("session not found")
ErrInvalidSessionData = fmt.Errorf("invalid session data")
)
// Session Session object for storing user data | 会话对象,用于存储用户数据
type Session struct {
ID string `json:"id"` // Session ID | Session标识
CreateTime int64 `json:"createTime"` // Creation time | 创建时间
Data map[string]interface{} `json:"data"` // Session data | 数据
mu sync.RWMutex `json:"-"` // Read-write lock | 读写锁
storage adapter.Storage `json:"-"` // Storage backend | 存储
prefix string `json:"-"` // Key prefix | 键前缀
ID string `json:"id"` // Session ID | Session标识
CreateTime int64 `json:"createTime"` // Creation time | 创建时间
Data map[string]any `json:"data"` // Session data | 数据
mu sync.RWMutex `json:"-"` // Read-write lock | 读写锁
storage adapter.Storage `json:"-"` // Storage backend | 存储
prefix string `json:"-"` // Key prefix | 键前缀
}
// NewSession creates a new session | 创建新的Session
// NewSession Creates a new session | 创建新的Session
func NewSession(id string, storage adapter.Storage, prefix string) *Session {
return &Session{
ID: id,
CreateTime: time.Now().Unix(),
Data: make(map[string]interface{}),
Data: make(map[string]any),
storage: storage,
prefix: prefix,
}
}
// Set sets value | 设置值
func (s *Session) Set(key string, value interface{}) error {
// ============ Data Operations | 数据操作 ============
// Set Sets value | 设置值
func (s *Session) Set(key string, value any) error {
if key == "" {
return fmt.Errorf("key cannot be empty")
}
s.mu.Lock()
defer s.mu.Unlock()
@@ -39,8 +56,8 @@ func (s *Session) Set(key string, value interface{}) error {
return s.save()
}
// Get gets value | 获取值
func (s *Session) Get(key string) (interface{}, bool) {
// Get Gets value | 获取值
func (s *Session) Get(key string) (any, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -116,16 +133,16 @@ func (s *Session) Delete(key string) error {
return s.save()
}
// Clear 清空所有数据
// Clear Clears all data | 清空所有数据
func (s *Session) Clear() error {
s.mu.Lock()
defer s.mu.Unlock()
s.Data = make(map[string]interface{})
s.Data = make(map[string]any)
return s.save()
}
// Keys 获取所有键
// Keys Gets all keys | 获取所有键
func (s *Session) Keys() []string {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -137,7 +154,7 @@ func (s *Session) Keys() []string {
return keys
}
// Size 获取数据数量
// Size Gets data count | 获取数据数量
func (s *Session) Size() int {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -145,32 +162,55 @@ func (s *Session) Size() int {
return len(s.Data)
}
// save 保存到存储
// IsEmpty Checks if session has no data | 检查Session是否为空
func (s *Session) IsEmpty() bool {
return s.Size() == 0
}
// ============ Internal Methods | 内部方法 ============
// save Saves session to storage | 保存到存储
func (s *Session) save() error {
data, err := json.Marshal(s)
if err != nil {
return fmt.Errorf("failed to marshal session: %w", err)
}
key := s.prefix + "session:" + s.ID
key := s.getStorageKey()
return s.storage.Set(key, string(data), 0)
}
// Load 从存储加载
// getStorageKey Gets storage key for this session | 获取Session的存储键
func (s *Session) getStorageKey() string {
return s.prefix + SessionKeyPrefix + s.ID
}
// ============ Static Methods | 静态方法 ============
// Load Loads session from storage | 从存储加载
func Load(id string, storage adapter.Storage, prefix string) (*Session, error) {
key := prefix + "session:" + id
if id == "" {
return nil, fmt.Errorf("session id cannot be empty")
}
key := prefix + SessionKeyPrefix + id
data, err := storage.Get(key)
if err != nil {
return nil, err
}
if data == nil {
return nil, fmt.Errorf("session not found")
return nil, ErrSessionNotFound
}
dataStr, ok := data.(string)
if !ok {
return nil, ErrInvalidSessionData
}
var session Session
if err := json.Unmarshal([]byte(data.(string)), &session); err != nil {
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
if err := json.Unmarshal([]byte(dataStr), &session); err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidSessionData, err)
}
session.storage = storage
@@ -178,8 +218,11 @@ func Load(id string, storage adapter.Storage, prefix string) (*Session, error) {
return &session, nil
}
// Destroy 销毁Session
// Destroy Destroys session | 销毁Session
func (s *Session) Destroy() error {
key := s.prefix + "session:" + s.ID
s.mu.Lock()
defer s.mu.Unlock()
key := s.getStorageKey()
return s.storage.Delete(key)
}
+115 -47
View File
@@ -3,36 +3,61 @@ package token
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"math/big"
"time"
"github.com/click33/sa-token-go/core/config"
"github.com/click33/sa-token-go/core/utils"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// Constants for token generation | Token生成常量
const (
DefaultJWTSecret = "default-secret-key" // Should be overridden in production | 生产环境应覆盖
TikTokenLength = 11 // TikTok-style short ID length | Tik风格短ID长度
TikCharset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
HashRandomBytesLen = 16 // Random bytes length for hash token | 哈希Token的随机字节长度
TimestampRandomLen = 8 // Random bytes length for timestamp token | 时间戳Token的随机字节长度
DefaultSimpleLength = 16 // Default simple token length | 默认简单Token长度
)
// Error variables | 错误变量
var (
ErrInvalidToken = fmt.Errorf("invalid token")
ErrUnexpectedSigningMethod = fmt.Errorf("unexpected signing method")
)
// Generator Token generator | Token生成器
type Generator struct {
config *config.Config
}
// NewGenerator creates a new token generator | 创建新的Token生成器
// NewGenerator Creates a new token generator | 创建新的Token生成器
func NewGenerator(cfg *config.Config) *Generator {
if cfg == nil {
cfg = config.DefaultConfig()
}
return &Generator{
config: cfg,
}
}
// Generate generates token based on configured style | 根据配置的风格生成Token
// ============ Public Methods | 公共方法 ============
// Generate Generates token based on configured style | 根据配置的风格生成Token
func (g *Generator) Generate(loginID string, device string) (string, error) {
if loginID == "" {
return "", fmt.Errorf("loginID cannot be empty")
}
switch g.config.TokenStyle {
case config.TokenStyleUUID:
return g.generateUUID()
case config.TokenStyleSimple:
return g.generateSimple(16)
return g.generateSimple(DefaultSimpleLength)
case config.TokenStyleRandom32:
return g.generateSimple(32)
case config.TokenStyleRandom64:
@@ -52,21 +77,31 @@ func (g *Generator) Generate(loginID string, device string) (string, error) {
}
}
// generateUUID generates UUID token | 生成UUID Token
// ============ Token Generation Methods | Token生成方法 ============
// generateUUID Generates UUID token | 生成UUID Token
func (g *Generator) generateUUID() (string, error) {
return uuid.New().String(), nil
}
// generateSimple generates simple random string token | 生成简单随机字符串Token
func (g *Generator) generateSimple(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", err
u, err := uuid.NewRandom()
if err != nil {
return "", fmt.Errorf("failed to generate UUID: %w", err)
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
return u.String(), nil
}
// generateJWT generates JWT token | 生成JWT Token
// generateSimple Generates simple random string token | 生成简单随机字符串Token
func (g *Generator) generateSimple(length int) (string, error) {
if length <= 0 {
length = DefaultSimpleLength
}
token := utils.RandomString(length)
if token == "" {
return "", fmt.Errorf("failed to generate random string")
}
return token, nil
}
// generateJWT Generates JWT token | 生成JWT Token
func (g *Generator) generateJWT(loginID string, device string) (string, error) {
now := time.Now()
claims := jwt.MapClaims{
@@ -75,71 +110,105 @@ func (g *Generator) generateJWT(loginID string, device string) (string, error) {
"iat": now.Unix(),
}
// Add expiration if timeout is configured | 如果配置了超时时间则添加过期时间
if g.config.Timeout > 0 {
claims["exp"] = now.Add(time.Duration(g.config.Timeout) * time.Second).Unix()
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
secretKey := g.config.JwtSecretKey
if secretKey == "" {
secretKey = "default-secret-key"
secretKey := g.getJWTSecret()
signedToken, err := token.SignedString([]byte(secretKey))
if err != nil {
return "", fmt.Errorf("failed to sign JWT token: %w", err)
}
return token.SignedString([]byte(secretKey))
return signedToken, nil
}
// ParseJWT parses JWT token and returns claims | 解析JWT Token并返回声明
// getJWTSecret Gets JWT secret key with fallback | 获取JWT密钥(带默认值)
func (g *Generator) getJWTSecret() string {
if g.config.JwtSecretKey != "" {
return g.config.JwtSecretKey
}
return DefaultJWTSecret
}
// ============ JWT Helper Methods | JWT辅助方法 ============
// ParseJWT Parses JWT token and returns claims | 解析JWT Token并返回声明
func (g *Generator) ParseJWT(tokenStr string) (jwt.MapClaims, error) {
secretKey := g.config.JwtSecretKey
if secretKey == "" {
secretKey = "default-secret-key"
if tokenStr == "" {
return nil, fmt.Errorf("token string cannot be empty")
}
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
secretKey := g.getJWTSecret()
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (any, error) {
// Verify signing method | 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
return nil, fmt.Errorf("%w: %v", ErrUnexpectedSigningMethod, token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse JWT: %w", err)
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
return claims, nil
}
return nil, fmt.Errorf("invalid token")
return nil, ErrInvalidToken
}
// ValidateJWT validates JWT token | 验证JWT Token
// ValidateJWT Validates JWT token | 验证JWT Token
func (g *Generator) ValidateJWT(tokenStr string) error {
_, err := g.ParseJWT(tokenStr)
return err
}
// generateHash generates SHA256 hash-based token | 生成SHA256哈希风格Token
func (g *Generator) generateHash(loginID string, device string) (string, error) {
// Combine loginID, device, timestamp and random bytes
// 组合 loginID、device、时间戳和随机字节
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
// GetLoginIDFromJWT Extracts login ID from JWT token | 从JWT Token中提取登录ID
func (g *Generator) GetLoginIDFromJWT(tokenStr string) (string, error) {
claims, err := g.ParseJWT(tokenStr)
if err != nil {
return "", err
}
data := fmt.Sprintf("%s:%s:%d:%s", loginID, device, time.Now().UnixNano(), hex.EncodeToString(randomBytes))
loginID, ok := claims["loginId"].(string)
if !ok {
return "", fmt.Errorf("loginId not found in token claims")
}
return loginID, nil
}
// generateHash Generates SHA256 hash-based token | 生成SHA256哈希风格Token
func (g *Generator) generateHash(loginID string, device string) (string, error) {
// Combine loginID, device, timestamp and random bytes | 组合 loginID、device、时间戳和随机字节
randomBytes := make([]byte, HashRandomBytesLen)
if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
// Create hash input | 创建哈希输入
data := fmt.Sprintf("%s:%s:%d:%s",
loginID,
device,
time.Now().UnixNano(),
hex.EncodeToString(randomBytes))
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:]), nil
}
// generateTimestamp generates timestamp-based token | 生成时间戳风格Token
// generateTimestamp Generates timestamp-based token | 生成时间戳风格Token
func (g *Generator) generateTimestamp(loginID string, device string) (string, error) {
// Format: timestamp_loginID_random
// 格式:时间戳_loginID_随机数
randomBytes := make([]byte, 8)
// Format: timestamp_loginID_random | 格式:时间戳_loginID_随机数
randomBytes := make([]byte, TimestampRandomLen)
if _, err := rand.Read(randomBytes); err != nil {
return "", err
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
timestamp := time.Now().UnixMilli()
@@ -147,18 +216,17 @@ func (g *Generator) generateTimestamp(loginID string, device string) (string, er
return fmt.Sprintf("%d_%s_%s", timestamp, loginID, random), nil
}
// generateTik generates short ID style token (like TikTok) | 生成Tik风格短ID Token(类似抖音)
// generateTik Generates short ID style token (like TikTok) | 生成Tik风格短ID Token(类似抖音)
func (g *Generator) generateTik() (string, error) {
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
const length = 11 // TikTok-style short ID length | 抖音风格短ID长度
result := make([]byte, TikTokenLength)
charsetLen := int64(len(TikCharset))
result := make([]byte, length)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
num, err := rand.Int(rand.Reader, big.NewInt(charsetLen))
if err != nil {
return "", err
return "", fmt.Errorf("failed to generate random number: %w", err)
}
result[i] = charset[num.Int64()]
result[i] = TikCharset[num.Int64()]
}
return string(result), nil
+408 -37
View File
@@ -2,19 +2,100 @@ package utils
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"math/big"
"reflect"
"strconv"
"strings"
"time"
)
// Constants for time durations | 时间常量
const (
Second = 1
Minute = 60 * Second
Hour = 60 * Minute
Day = 24 * Hour
Week = 7 * Day
)
// Constants for string operations | 字符串操作常量
const (
DefaultSeparator = ","
WildcardChar = "*"
)
// ============ Random Generation | 随机生成 ============
// RandomString generates random string of specified length | 生成指定长度的随机字符串
func RandomString(length int) string {
bytes := make([]byte, length)
if length <= 0 {
return ""
}
// Calculate required byte length (base64 expands by ~33%)
byteLen := (length * 3) / 4
if byteLen < length {
byteLen = length
}
bytes := make([]byte, byteLen)
if _, err := rand.Read(bytes); err != nil {
return ""
}
return base64.URLEncoding.EncodeToString(bytes)[:length]
encoded := base64.URLEncoding.EncodeToString(bytes)
// Remove padding and trim to exact length
encoded = strings.TrimRight(encoded, "=")
if len(encoded) > length {
return encoded[:length]
}
return encoded
}
// RandomNumericString generates random numeric string | 生成随机数字字符串
func RandomNumericString(length int) string {
if length <= 0 {
return ""
}
const digits = "0123456789"
result := make([]byte, length)
max := big.NewInt(int64(len(digits)))
for i := 0; i < length; i++ {
n, err := rand.Int(rand.Reader, max)
if err != nil {
return ""
}
result[i] = digits[n.Int64()]
}
return string(result)
}
// RandomAlphanumeric generates random alphanumeric string | 生成随机字母数字字符串
func RandomAlphanumeric(length int) string {
if length <= 0 {
return ""
}
const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
result := make([]byte, length)
max := big.NewInt(int64(len(chars)))
for i := 0; i < length; i++ {
n, err := rand.Int(rand.Reader, max)
if err != nil {
return ""
}
result[i] = chars[n.Int64()]
}
return string(result)
}
// IsEmpty checks if string is empty | 检查字符串是否为空
@@ -47,7 +128,7 @@ func ContainsString(slice []string, item string) bool {
// RemoveString removes item from string slice | 从字符串数组中移除指定字符串
func RemoveString(slice []string, item string) []string {
result := make([]string, 0)
result := make([]string, 0, len(slice))
for _, s := range slice {
if s != item {
result = append(result, s)
@@ -58,8 +139,12 @@ func RemoveString(slice []string, item string) []string {
// UniqueStrings removes duplicates from string slice | 字符串数组去重
func UniqueStrings(slice []string) []string {
seen := make(map[string]bool)
result := make([]string, 0)
if len(slice) == 0 {
return []string{}
}
seen := make(map[string]bool, len(slice))
result := make([]string, 0, len(slice))
for _, s := range slice {
if !seen[s] {
seen[s] = true
@@ -69,19 +154,53 @@ func UniqueStrings(slice []string) []string {
return result
}
// MergeStrings 合并多个字符串数组并去重
// FilterStrings filters string slice by predicate | 根据条件过滤字符串数组
func FilterStrings(slice []string, predicate func(string) bool) []string {
result := make([]string, 0, len(slice))
for _, s := range slice {
if predicate(s) {
result = append(result, s)
}
}
return result
}
// MapStrings applies function to each string in slice | 对数组中每个字符串应用函数
func MapStrings(slice []string, mapper func(string) string) []string {
result := make([]string, len(slice))
for i, s := range slice {
result[i] = mapper(s)
}
return result
}
// MergeStrings Merges multiple string slices and removes duplicates | 合并多个字符串数组并去重
func MergeStrings(slices ...[]string) []string {
result := make([]string, 0)
if len(slices) == 0 {
return []string{}
}
// Pre-calculate total capacity
totalLen := 0
for _, slice := range slices {
totalLen += len(slice)
}
result := make([]string, 0, totalLen)
for _, slice := range slices {
result = append(result, slice...)
}
return UniqueStrings(result)
}
// SplitAndTrim 分割字符串并去除空格
// SplitAndTrim Splits string and trims whitespace | 分割字符串并去除空格
func SplitAndTrim(s, sep string) []string {
if s == "" {
return []string{}
}
parts := strings.Split(s, sep)
result := make([]string, 0)
result := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
@@ -91,6 +210,17 @@ func SplitAndTrim(s, sep string) []string {
return result
}
// JoinNonEmpty Joins non-empty strings | 连接非空字符串
func JoinNonEmpty(sep string, strs ...string) string {
nonEmpty := make([]string, 0, len(strs))
for _, s := range strs {
if IsNotEmpty(s) {
nonEmpty = append(nonEmpty, s)
}
}
return strings.Join(nonEmpty, sep)
}
// GetStructTag 获取结构体字段的标签值
func GetStructTag(field reflect.StructField, tag string) string {
return field.Tag.Get(tag)
@@ -120,18 +250,18 @@ func ParseRoleTag(tag string) []string {
return SplitAndTrim(tag, ",")
}
// MatchPattern 模式匹配(支持通配符*)
// MatchPattern Pattern matching with wildcard support | 模式匹配(支持通配符*
func MatchPattern(pattern, str string) bool {
if pattern == "*" {
if pattern == WildcardChar {
return true
}
if !strings.Contains(pattern, "*") {
if !strings.Contains(pattern, WildcardChar) {
return pattern == str
}
// 简单的通配符匹配
parts := strings.Split(pattern, "*")
// Simple wildcard matching | 简单的通配符匹配
parts := strings.Split(pattern, WildcardChar)
if len(parts) == 2 {
prefix, suffix := parts[0], parts[1]
if prefix != "" && !strings.HasPrefix(str, prefix) {
@@ -143,61 +273,302 @@ func MatchPattern(pattern, str string) bool {
return true
}
return false
// Complex pattern with multiple wildcards | 复杂模式(多个通配符)
pos := 0
for i, part := range parts {
if i == 0 && part != "" {
if !strings.HasPrefix(str, part) {
return false
}
pos += len(part)
continue
}
if i == len(parts)-1 && part != "" {
return strings.HasSuffix(str, part)
}
if part == "" {
continue
}
idx := strings.Index(str[pos:], part)
if idx == -1 {
return false
}
pos += idx + len(part)
}
return true
}
// FormatDuration 格式化时间段(秒)为人类可读格式
// ============ Time & Duration | 时间和时长 ============
// FormatDuration Formats duration in seconds to human-readable format | 格式化时间段(秒)为人类可读格式
func FormatDuration(seconds int64) string {
if seconds < 0 {
return "永久"
}
if seconds < 60 {
if seconds == 0 {
return "0秒"
}
if seconds < Minute {
return fmt.Sprintf("%d秒", seconds)
}
if seconds < 3600 {
minutes := seconds / 60
if seconds < Hour {
minutes := seconds / Minute
return fmt.Sprintf("%d分钟", minutes)
}
if seconds < 86400 {
hours := seconds / 3600
if seconds < Day {
hours := seconds / Hour
return fmt.Sprintf("%d小时", hours)
}
days := seconds / 86400
return fmt.Sprintf("%d天", days)
if seconds < Week {
days := seconds / Day
return fmt.Sprintf("%d天", days)
}
weeks := seconds / Week
return fmt.Sprintf("%d周", weeks)
}
// ParseDuration 解析人类可读的时间段为秒
// ParseDuration Parses human-readable duration to seconds | 解析人类可读的时间段为秒
func ParseDuration(duration string) int64 {
duration = strings.ToLower(strings.TrimSpace(duration))
if strings.HasSuffix(duration, "s") || strings.HasSuffix(duration, "") {
return parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "s"), "秒"))
if duration == "" {
return 0
}
if strings.HasSuffix(duration, "m") || strings.HasSuffix(duration, "分") {
minutes := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "m"), ""))
return minutes * 60
}
if strings.HasSuffix(duration, "h") || strings.HasSuffix(duration, "时") {
hours := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "h"), "时"))
return hours * 3600
// Week | 周
if strings.HasSuffix(duration, "w") || strings.HasSuffix(duration, "") {
weeks := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "w"), "周"))
return weeks * Week
}
// Day | 天
if strings.HasSuffix(duration, "d") || strings.HasSuffix(duration, "天") {
days := parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "d"), "天"))
return days * 86400
return days * Day
}
// Hour | 小时
if strings.HasSuffix(duration, "h") || strings.HasSuffix(duration, "时") || strings.HasSuffix(duration, "小时") {
hours := parseInt64(strings.TrimSuffix(strings.TrimSuffix(strings.TrimSuffix(duration, "h"), "时"), "小时"))
return hours * Hour
}
// Minute | 分钟
if strings.HasSuffix(duration, "m") || strings.HasSuffix(duration, "分") || strings.HasSuffix(duration, "分钟") {
minutes := parseInt64(strings.TrimSuffix(strings.TrimSuffix(strings.TrimSuffix(duration, "m"), "分"), "分钟"))
return minutes * Minute
}
// Second | 秒
if strings.HasSuffix(duration, "s") || strings.HasSuffix(duration, "秒") {
return parseInt64(strings.TrimSuffix(strings.TrimSuffix(duration, "s"), "秒"))
}
return parseInt64(duration)
}
// TimestampToTime Converts Unix timestamp to time.Time | Unix时间戳转time.Time
func TimestampToTime(timestamp int64) time.Time {
return time.Unix(timestamp, 0)
}
// TimeToTimestamp Converts time.Time to Unix timestamp | time.Time转Unix时间戳
func TimeToTimestamp(t time.Time) int64 {
return t.Unix()
}
// ============ Type Conversion | 类型转换 ============
// parseInt64 Parses string to int64 | 将字符串解析为int64
func parseInt64(s string) int64 {
var result int64
fmt.Sscanf(s, "%d", &result)
s = strings.TrimSpace(s)
if s == "" {
return 0
}
result, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0
}
return result
}
// ToInt Converts any to int | 将any转换为int
func ToInt(v any) (int, error) {
switch val := v.(type) {
case int:
return val, nil
case int32:
return int(val), nil
case int64:
return int(val), nil
case float32:
return int(val), nil
case float64:
return int(val), nil
case string:
return strconv.Atoi(val)
default:
return 0, fmt.Errorf("cannot convert %T to int", v)
}
}
// ToInt64 Converts any to int64 | 将any转换为int64
func ToInt64(v any) (int64, error) {
switch val := v.(type) {
case int:
return int64(val), nil
case int32:
return int64(val), nil
case int64:
return val, nil
case float32:
return int64(val), nil
case float64:
return int64(val), nil
case string:
return strconv.ParseInt(val, 10, 64)
default:
return 0, fmt.Errorf("cannot convert %T to int64", v)
}
}
// ToString Converts any to string | 将any转换为string
func ToString(v any) string {
if v == nil {
return ""
}
switch val := v.(type) {
case string:
return val
case []byte:
return string(val)
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", val)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", val)
case float32, float64:
return fmt.Sprintf("%v", val)
case bool:
return strconv.FormatBool(val)
default:
return fmt.Sprintf("%v", val)
}
}
// ToBool Converts any to bool | 将any转换为bool
func ToBool(v any) (bool, error) {
switch val := v.(type) {
case bool:
return val, nil
case string:
return strconv.ParseBool(val)
case int, int8, int16, int32, int64:
return val != 0, nil
default:
return false, fmt.Errorf("cannot convert %T to bool", v)
}
}
// ============ Hash & Encoding | 哈希和编码 ============
// SHA256Hash Generates SHA256 hash of string | 生成字符串的SHA256哈希
func SHA256Hash(s string) string {
hash := sha256.Sum256([]byte(s))
return hex.EncodeToString(hash[:])
}
// Base64Encode Encodes string to base64 | Base64编码
func Base64Encode(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
// Base64Decode Decodes base64 string | Base64解码
func Base64Decode(s string) (string, error) {
data, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return "", err
}
return string(data), nil
}
// Base64URLEncode Encodes string to URL-safe base64 | URL安全的Base64编码
func Base64URLEncode(s string) string {
return base64.URLEncoding.EncodeToString([]byte(s))
}
// Base64URLDecode Decodes URL-safe base64 string | URL安全的Base64解码
func Base64URLDecode(s string) (string, error) {
data, err := base64.URLEncoding.DecodeString(s)
if err != nil {
return "", err
}
return string(data), nil
}
// ============ Validation | 验证 ============
// IsAlphanumeric Checks if string contains only alphanumeric characters | 检查是否只包含字母数字
func IsAlphanumeric(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) {
return false
}
}
return true
}
// IsNumeric Checks if string contains only numbers | 检查是否只包含数字
func IsNumeric(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r < '0' || r > '9' {
return false
}
}
return true
}
// HasLength Checks if string length is within range | 检查字符串长度是否在范围内
func HasLength(s string, min, max int) bool {
length := len(s)
return length >= min && length <= max
}
// ============ Slice Helpers | 切片辅助 ============
// InSlice Checks if value exists in slice | 检查值是否存在于切片中
func InSlice[T comparable](slice []T, val T) bool {
for _, item := range slice {
if item == val {
return true
}
}
return false
}
// UniqueSlice Removes duplicates from slice | 去除切片中的重复元素
func UniqueSlice[T comparable](slice []T) []T {
seen := make(map[T]bool, len(slice))
result := make([]T, 0, len(slice))
for _, item := range slice {
if !seen[item] {
seen[item] = true
result = append(result, item)
}
}
return result
}
+45
View File
@@ -0,0 +1,45 @@
package main
import (
"github.com/click33/sa-token-go/core/banner"
"github.com/click33/sa-token-go/core/config"
)
func main() {
// 1. 打印基础 Banner
banner.Print()
// 2. 打印带完整配置的 Banner
cfg := config.DefaultConfig()
banner.PrintWithConfig(cfg)
// 3. 打印 JWT 配置的 Banner
jwtCfg := &config.Config{
TokenName: "jwt-token",
Timeout: 86400, // 24小时
ActiveTimeout: -1,
IsConcurrent: true,
IsShare: false,
MaxLoginCount: 5,
IsReadBody: false,
IsReadHeader: true,
IsReadCookie: true,
TokenStyle: config.TokenStyleJWT,
DataRefreshPeriod: -1,
TokenSessionCheckLogin: true,
AutoRenew: true,
JwtSecretKey: "my-super-secret-key-123456",
IsLog: true,
IsPrintBanner: true,
CookieConfig: &config.CookieConfig{
Domain: "example.com",
Path: "/api",
Secure: true,
HttpOnly: true,
SameSite: config.SameSiteStrict,
MaxAge: 7200,
},
}
banner.PrintWithConfig(jwtCfg)
}
+18 -5
View File
@@ -19,7 +19,7 @@ var (
// item 存储项
type item struct {
value interface{}
value any
expiration int64 // 过期时间戳(0表示永不过期)
}
@@ -54,7 +54,7 @@ func NewStorageWithCleanupInterval(interval time.Duration) adapter.Storage {
}
// Set 设置键值对
func (s *Storage) Set(key string, value interface{}, expiration time.Duration) error {
func (s *Storage) Set(key string, value any, expiration time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -72,7 +72,7 @@ func (s *Storage) Set(key string, value interface{}, expiration time.Duration) e
}
// Get 获取值
func (s *Storage) Get(key string) (interface{}, error) {
func (s *Storage) Get(key string) (any, error) {
now := time.Now().Unix()
s.mu.RLock()
@@ -93,11 +93,13 @@ func (s *Storage) Get(key string) (interface{}, error) {
}
// Delete 删除键
func (s *Storage) Delete(key string) error {
func (s *Storage) Delete(keys ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.data, key)
for _, key := range keys {
delete(s.data, key)
}
return nil
}
@@ -194,6 +196,17 @@ func (s *Storage) Clear() error {
return nil
}
// Ping 检查存储可用性
func (s *Storage) Ping() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.closed {
return errors.New("storage is closed")
}
return nil
}
// Close 关闭存储,停止清理协程
func (s *Storage) Close() error {
s.mu.Lock()
+20 -4
View File
@@ -104,14 +104,14 @@ func (s *Storage) getKey(key string) string {
}
// Set 设置键值对
func (s *Storage) Set(key string, value interface{}, expiration time.Duration) error {
func (s *Storage) Set(key string, value any, expiration time.Duration) error {
ctx, cancel := s.withTimeout()
defer cancel()
return s.client.Set(ctx, s.getKey(key), value, expiration).Err()
}
// Get 获取值
func (s *Storage) Get(key string) (interface{}, error) {
func (s *Storage) Get(key string) (any, error) {
ctx, cancel := s.withTimeout()
defer cancel()
val, err := s.client.Get(ctx, s.getKey(key)).Result()
@@ -125,10 +125,19 @@ func (s *Storage) Get(key string) (interface{}, error) {
}
// Delete 删除键
func (s *Storage) Delete(key string) error {
func (s *Storage) Delete(keys ...string) error {
if len(keys) == 0 {
return nil
}
ctx, cancel := s.withTimeout()
defer cancel()
return s.client.Del(ctx, s.getKey(key)).Err()
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = s.getKey(key)
}
return s.client.Del(ctx, fullKeys...).Err()
}
// Exists 检查键是否存在
@@ -215,6 +224,13 @@ func (s *Storage) Clear() error {
return nil
}
// Ping 检查连接
func (s *Storage) Ping() error {
ctx, cancel := s.withTimeout()
defer cancel()
return s.client.Ping(ctx).Err()
}
// Close 关闭连接
func (s *Storage) Close() error {
return s.client.Close()