tencent_ocr/.history/pkg/middleware/ratelimit_20250115162845.go
2025-01-15 16:59:27 +08:00

76 lines
1.4 KiB
Go

package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
// RateLimiter holds one *rate.Limiter per client key, all created with the
// same rate r and burst b (see getLimiter).
type RateLimiter struct {
	// limiters maps a client key to its limiter; guarded by mu.
	limiters map[string]*rate.Limiter
	// mu guards limiters.
	mu sync.RWMutex
	// r is the rate passed to rate.NewLimiter for each newly seen key.
	r rate.Limit
	// b is the burst passed to rate.NewLimiter for each newly seen key.
	b int
}
// NewRateLimiter returns a RateLimiter with no per-key limiters yet. Each
// limiter it later creates uses rate r and burst b.
//
// It does not start CleanupTask; callers that want eviction must start it.
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
	rl := new(RateLimiter)
	rl.limiters = map[string]*rate.Limiter{}
	rl.r, rl.b = r, b
	return rl
}
// getLimiter returns the limiter registered for key. The first time a key is
// seen, it creates one with the configured rate and burst and registers it.
//
// Lookups of keys that already have a limiter take only the read lock, so
// concurrent requests from known clients do not serialize on each other.
// The write lock is taken only when a new limiter has to be inserted.
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
	rl.mu.RLock()
	limiter, ok := rl.limiters[key]
	rl.mu.RUnlock()
	if ok {
		return limiter
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	// Re-check under the write lock: another goroutine may have inserted a
	// limiter for key between RUnlock and Lock, and we must not replace it.
	if limiter, ok = rl.limiters[key]; ok {
		return limiter
	}
	limiter = rate.NewLimiter(rl.r, rl.b)
	rl.limiters[key] = limiter
	return limiter
}
// RateLimit returns a gin middleware that throttles each client with its own
// limiter (rate r, burst burst).
//
// Clients are keyed by the value of the APIKeyHeader request header. When
// that header is empty or missing, c.ClientIP() is used instead. A request
// that exceeds its client's limit is aborted with HTTP 429 and a JSON body
// {"success": false, "error": "Rate limit exceeded"}. Other requests continue
// down the handler chain.
//
// NOTE(review): CleanupTask is never started on the RateLimiter created here,
// so its map gains one entry per distinct key and never shrinks. Confirm this
// is acceptable.
// NOTE(review): the header value is used as-is. If this middleware runs
// before API-key validation, a client could send a different value on every
// request and get a fresh bucket each time. Confirm middleware ordering.
func RateLimit(r rate.Limit, burst int) gin.HandlerFunc {
	limiters := NewRateLimiter(r, burst)

	return func(c *gin.Context) {
		clientKey := c.GetHeader(APIKeyHeader)
		if clientKey == "" {
			clientKey = c.ClientIP()
		}

		if limiters.getLimiter(clientKey).Allow() {
			c.Next()
			return
		}

		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
			"success": false,
			"error":   "Rate limit exceeded",
		})
	}
}
// CleanupTask starts a background goroutine that, once per interval, evicts
// limiters that are currently idle.
//
// A limiter counts as idle when its bucket has fully refilled
// (Tokens() >= Burst()). Such a limiter behaves exactly like the fresh one
// getLimiter would create for the same key, so evicting it frees memory
// without giving that key any extra allowance. Limiters that are still
// refilling are kept, so a client that is being throttled does not regain a
// full burst just because a cleanup tick fired. (The previous implementation
// deleted every limiter, which reset all clients on each tick.)
//
// A request that fetched a limiter just before it was evicted may still draw
// from the evicted instance. That key can therefore briefly get slightly more
// than its burst.
//
// The goroutine and its ticker are never stopped; they run for the rest of
// the process. time.NewTicker panics if interval is not positive.
//
// NOTE(review): Limiter.Tokens requires golang.org/x/time v0.3.0 or later.
// Confirm the version pinned in go.mod.
func (rl *RateLimiter) CleanupTask(interval time.Duration) {
	ticker := time.NewTicker(interval)
	go func() {
		for range ticker.C {
			rl.mu.Lock()
			for key, limiter := range rl.limiters {
				// Deleting during range is permitted in Go.
				if limiter.Tokens() >= float64(limiter.Burst()) {
					delete(rl.limiters, key)
				}
			}
			rl.mu.Unlock()
		}
	}()
}