74 lines
2.3 KiB
Go
74 lines
2.3 KiB
Go
|
|
package mw
|
||
|
|
|
||
|
|
import (
	"encoding/json"
	"net"
	"net/http"
	"strings"

	"code.nochebuena.dev/einherjar/contracts/logging"
	"code.nochebuena.dev/einherjar/contracts/security"
)
|
||
|
|
|
||
|
|
// IPRateLimit builds HTTP middleware that throttles requests per client IP.
//
// The key passed to store.Allow is produced by clientIP: the first
// X-Forwarded-For entry when present, otherwise the host part of RemoteAddr.
// When Allow reports false, the request is rejected with a 429 JSON response
// and next is not invoked. When Allow returns an error, the middleware fails
// open: it logs a warning through logger and lets the request continue to next.
func IPRateLimit(store RateLimiterStore, logger logging.Logger) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			allowed, err := store.Allow(ctx, clientIP(r))
			switch {
			case err != nil:
				// A broken limiter backend must not take the service down,
				// so the request is let through after logging the failure.
				logger.WithContext(ctx).Warn("rate_limit: store error, failing open", "err", err.Error())
			case !allowed:
				rateLimitExceeded(w)
				return
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
|
||
|
|
|
||
|
|
// UserRateLimit returns middleware that rate-limits by authenticated user ID.
|
||
|
|
// Falls back to client IP when no [security.Identity] is present in the context.
|
||
|
|
// When the store returns an error the middleware fails open.
|
||
|
|
func UserRateLimit(store RateLimiterStore, logger logging.Logger) func(http.Handler) http.Handler {
|
||
|
|
return func(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
key := clientIP(r)
|
||
|
|
if id, ok := security.FromContext(r.Context()); ok && id.UID != "" {
|
||
|
|
key = id.UID
|
||
|
|
}
|
||
|
|
ok, err := store.Allow(r.Context(), key)
|
||
|
|
if err != nil {
|
||
|
|
logger.WithContext(r.Context()).Warn("rate_limit: store error, failing open", "err", err.Error())
|
||
|
|
} else if !ok {
|
||
|
|
rateLimitExceeded(w)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// clientIP returns the address used as the rate-limit key for r.
//
// If an X-Forwarded-For header is present, its first comma-separated entry
// (whitespace-trimmed) is used. An empty first entry is ignored rather than
// returned, so malformed headers do not pool unrelated clients under the
// single key "".
//
// Otherwise the host part of r.RemoteAddr is returned. net.SplitHostPort is
// used so IPv6 addresses such as "[::1]:8080" yield "::1". That matches the
// unbracketed form typically found in X-Forwarded-For, instead of "[::1]".
// If RemoteAddr has no port (SplitHostPort fails), it is returned unchanged.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// proxy in front of this service overwrites it. A direct client can therefore
// pick its own key and evade the limit. Rely on this only behind such a proxy.
func clientIP(r *http.Request) string {
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		first := strings.SplitN(fwd, ",", 2)[0]
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
||
|
|
|
||
|
|
// rateLimitExceeded writes a 429 Too Many Requests response with a JSON body
// of the form {"code":"RESOURCE_EXHAUSTED","message":"too many requests"}
// and Content-Type set to application/json.
func rateLimitExceeded(w http.ResponseWriter) {
	payload := struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	}{
		Code:    "RESOURCE_EXHAUSTED",
		Message: "too many requests",
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusTooManyRequests)

	// The status line has already been sent, so an encode/write failure
	// cannot be reported to the client; it is deliberately discarded.
	_ = json.NewEncoder(w).Encode(payload)
}
|