package jwtauth

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// ErrTokenRevoked is returned by RefreshTokenPair when the refresh token's JTI
// is already on the blacklist.
// Callers should respond with 401 and prompt re-authentication.
var ErrTokenRevoked = errors.New("token revoked")

// Blacklist records and checks revoked refresh token JTIs.
// Implementations are typically backed by Valkey or Redis.
// The TTL passed to Revoke should match the token's remaining lifetime so
// entries expire naturally.
type Blacklist interface {
	// IsRevoked reports whether jti has been revoked. RefreshTokenPair treats a
	// non-nil error as a failed check and aborts the refresh.
	IsRevoked(ctx context.Context, jti string) (bool, error)
	// Revoke records jti as revoked; ttl is how long the entry should be kept.
	Revoke(ctx context.Context, jti string, ttl time.Duration) error
}

// RefreshTokenPair validates refreshToken, checks the blacklist, revokes the old
// JTI, and issues a new token pair for the same uid.
//
// customClaims are merged into the new access token — typically the caller
// re-fetches fresh permission masks here so the new token reflects any role
// changes made since the previous issue.
//
// Returns ErrTokenRevoked if the JTI is already on the blacklist (replay attack
// or re-use after rotation). Any other error indicates an infrastructure or
// token fault.
func RefreshTokenPair(ctx context.Context, signer Signer, refreshToken string, bl Blacklist, cfg TokenConfig, customClaims map[string]any) (TokenPair, error) { token, err := signer.Verify(refreshToken) if err != nil { return TokenPair{}, fmt.Errorf("invalid refresh token: %w", err) } claims, ok := token.Claims.(jwt.MapClaims) if !ok { return TokenPair{}, fmt.Errorf("unexpected claims type in refresh token") } jti, _ := claims["jti"].(string) uid, _ := claims["sub"].(string) if jti == "" || uid == "" { return TokenPair{}, fmt.Errorf("missing required claims in refresh token") } revoked, err := bl.IsRevoked(ctx, jti) if err != nil { return TokenPair{}, fmt.Errorf("blacklist check: %w", err) } if revoked { return TokenPair{}, ErrTokenRevoked } expTime, err := claims.GetExpirationTime() if err != nil || expTime == nil { return TokenPair{}, fmt.Errorf("invalid expiration in refresh token") } remaining := time.Until(expTime.Time) if remaining < time.Second { remaining = time.Second } if err := bl.Revoke(ctx, jti, remaining); err != nil { return TokenPair{}, fmt.Errorf("revoke old token: %w", err) } return IssueTokenPair(signer, uid, customClaims, cfg) }