package mw

import (
	"encoding/json"
	"net/http"
	"strings"

	"code.nochebuena.dev/einherjar/contracts/logging"
	"code.nochebuena.dev/einherjar/contracts/security"
)

// IPRateLimit builds a middleware that throttles requests per client IP.
//
// The key is the first entry of the X-Forwarded-For header when that header is
// present, and is otherwise derived from r.RemoteAddr with the port stripped.
// X-Forwarded-For is an ordinary request header: unless a trusted reverse proxy
// strips or overwrites it, its value is chosen by the client.
//
// Requests rejected by the store get a 429 JSON response and are not passed
// on. If store.Allow returns an error, the error is logged at warn level and
// the request proceeds anyway (fail-open), so a store outage never blocks
// traffic.
func IPRateLimit(store RateLimiterStore, logger logging.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			allowed, err := store.Allow(ctx, clientIP(r))
			switch {
			case err != nil:
				// Fail open: availability wins over strict enforcement.
				logger.WithContext(ctx).Warn("rate_limit: store error, failing open", "err", err.Error())
			case !allowed:
				rateLimitExceeded(w)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// UserRateLimit builds a middleware that throttles requests per authenticated
// user. The key is the UID of the [security.Identity] stored in the request
// context; when no identity is present, or its UID is empty, the client IP is
// used instead. Rejected requests get a 429 JSON response; store errors are
// logged and the request is let through (fail-open).
func UserRateLimit(store RateLimiterStore, logger logging.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key := clientIP(r) if id, ok := security.FromContext(r.Context()); ok && id.UID != "" { key = id.UID } ok, err := store.Allow(r.Context(), key) if err != nil { logger.WithContext(r.Context()).Warn("rate_limit: store error, failing open", "err", err.Error()) } else if !ok { rateLimitExceeded(w) return } next.ServeHTTP(w, r) }) } } func clientIP(r *http.Request) string { if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" { parts := strings.SplitN(fwd, ",", 2) return strings.TrimSpace(parts[0]) } addr := r.RemoteAddr if i := strings.LastIndex(addr, ":"); i >= 0 { return addr[:i] } return addr } func rateLimitExceeded(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) _ = json.NewEncoder(w).Encode(map[string]string{ "code": "RESOURCE_EXHAUSTED", "message": "too many requests", }) }