package middleware import ( "context" "fmt" "github.com/gin-gonic/gin" "message-pusher/common" "net/http" "time" ) var timeFormat = "2006-01-02T15:04:05.000Z" var inMemoryRateLimiter common.InMemoryRateLimiter func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { ctx := context.Background() rdb := common.RDB key := "rateLimit:" + mark + c.ClientIP() listLength, err := rdb.LLen(ctx, key).Result() if err != nil { fmt.Println(err.Error()) c.Status(http.StatusInternalServerError) c.Abort() return } if listLength < int64(maxRequestNum) { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) } else { oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { fmt.Println(err) c.Status(http.StatusInternalServerError) c.Abort() return } nowTimeStr := time.Now().Format(timeFormat) nowTime, err := time.Parse(timeFormat, nowTimeStr) if err != nil { fmt.Println(err) c.Status(http.StatusInternalServerError) c.Abort() return } // time.Since will return negative number! // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows if int64(nowTime.Sub(oldTime).Seconds()) < duration { rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) c.Status(http.StatusTooManyRequests) c.Abort() return } else { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) } } } func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { key := mark + c.ClientIP() if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } } func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { if common.RedisEnabled { return func(c *gin.Context) { redisRateLimiter(c, maxRequestNum, duration, mark) } } else { // It's safe to call multi times. inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) return func(c *gin.Context) { memoryRateLimiter(c, maxRequestNum, duration, mark) } } } func GlobalWebRateLimit() func(c *gin.Context) { return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") } func GlobalAPIRateLimit() func(c *gin.Context) { return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") } func CriticalRateLimit() func(c *gin.Context) { return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") } func DownloadRateLimit() func(c *gin.Context) { return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") } func UploadRateLimit() func(c *gin.Context) { return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") }