Files
httpauth-jwt/refresh.go

73 lines
2.3 KiB
Go
Raw Normal View History

package jwtauth
import (
"context"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
)
// ErrTokenRevoked is the sentinel error RefreshTokenPair returns when the
// presented refresh token's JTI is already on the blacklist. Match it with
// errors.Is; callers should respond with 401 and prompt re-authentication.
var (
	ErrTokenRevoked = errors.New("token revoked")
)
// Blacklist is the revocation store for refresh token JTIs: it records revoked
// JTIs and answers whether a given JTI has been revoked.
// Implementations are typically backed by Valkey or Redis.
type Blacklist interface {
	// Revoke marks jti as revoked for the duration ttl. The ttl should match
	// the token's remaining lifetime so entries expire naturally.
	Revoke(ctx context.Context, jti string, ttl time.Duration) error

	// IsRevoked reports whether jti has been revoked. RefreshTokenPair treats
	// a non-nil error as a failed lookup, not as a revocation.
	IsRevoked(ctx context.Context, jti string) (bool, error)
}
// RefreshTokenPair rotates a refresh token: it verifies refreshToken with signer,
// rejects it if its JTI is already on the blacklist, revokes that JTI, and
// issues a new token pair for the same uid (the token's "sub" claim).
//
// customClaims are forwarded to IssueTokenPair for the new token pair — typically
// the caller re-fetches fresh permission masks here so the new token reflects any
// role changes made since the previous issue.
//
// The old JTI is revoked for the token's remaining lifetime, with a floor of one
// second, so the blacklist entry outlives the token it guards.
//
// Returns ErrTokenRevoked if the JTI is already on the blacklist (replay attack or
// re-use after rotation). Any other error indicates an infrastructure or token
// fault; underlying causes are wrapped and can be inspected with errors.Is/As.
//
// NOTE(review): IsRevoked and Revoke are two separate calls, so two concurrent
// refreshes presenting the same token can both pass the check. Closing that
// window needs an atomic check-and-set in the Blacklist implementation.
func RefreshTokenPair(ctx context.Context, signer Signer, refreshToken string, bl Blacklist, cfg TokenConfig, customClaims map[string]any) (TokenPair, error) {
	token, err := signer.Verify(refreshToken)
	if err != nil {
		return TokenPair{}, fmt.Errorf("invalid refresh token: %w", err)
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return TokenPair{}, errors.New("unexpected claims type in refresh token")
	}

	// A failed type assertion yields "", so absent and non-string claims are
	// both rejected by the emptiness check below.
	jti, _ := claims["jti"].(string)
	uid, _ := claims["sub"].(string)
	if jti == "" || uid == "" {
		return TokenPair{}, errors.New("missing required claims in refresh token")
	}

	revoked, err := bl.IsRevoked(ctx, jti)
	if err != nil {
		return TokenPair{}, fmt.Errorf("blacklist check: %w", err)
	}
	if revoked {
		return TokenPair{}, ErrTokenRevoked
	}

	// Keep the parse error in the chain instead of discarding it, and report a
	// missing exp claim (nil, nil) separately from a malformed one.
	expTime, err := claims.GetExpirationTime()
	if err != nil {
		return TokenPair{}, fmt.Errorf("invalid expiration in refresh token: %w", err)
	}
	if expTime == nil {
		return TokenPair{}, errors.New("invalid expiration in refresh token")
	}

	// Clamp to at least one second so a token on the edge of expiry still
	// produces a positive TTL for the blacklist entry.
	remaining := time.Until(expTime.Time)
	if remaining < time.Second {
		remaining = time.Second
	}

	// Revoke before issuing so the old token can never be redeemed twice.
	if err := bl.Revoke(ctx, jti, remaining); err != nil {
		return TokenPair{}, fmt.Errorf("revoke old token: %w", err)
	}

	return IssueTokenPair(signer, uid, customClaims, cfg)
}