Files

85 lines
2.1 KiB
Go
Raw Permalink Normal View History

package httpauthjwt
import (
"crypto/rand"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
)
// TokenConfig holds the settings IssueTokenPair uses when minting tokens:
// how long each token remains valid and which issuer both tokens name.
type TokenConfig struct {
	AccessTTL  time.Duration // access token lifetime; its exp claim is issue time + AccessTTL
	RefreshTTL time.Duration // refresh token lifetime; its exp claim is issue time + RefreshTTL
	Issuer     string        // written to the iss claim of both tokens
}
// TokenPair is what IssueTokenPair returns: the two signed tokens plus the
// access token's lifetime in seconds.
type TokenPair struct {
	AccessToken  string // signed access token, including any custom claims
	RefreshToken string // signed refresh token carrying only sub, iss, iat, exp, jti and fam
	ExpiresIn    int64  // seconds until the access token expires
}
// IssueTokenPair signs a new access + refresh token pair for uid.
//
// customClaims are merged into the access token at the top level. Use this to
// embed per-resource permission masks so ClaimsPermissionProvider can read them
// without a database call:
//
//	customClaims := map[string]any{
//		"permisos": map[string]any{"usuarios": int64(515), "roles": int64(30)},
//	}
//
// The registered claims set here (sub, iss, iat, exp, jti) are written after
// customClaims, so a custom entry using one of those keys is ignored. It cannot
// change the token's subject, issuer, issue time, lifetime, or ID.
// customClaims itself is only read, never modified.
//
// The refresh token contains only sub, iss, iat, exp, jti, and fam (token family).
// It carries no permission data — callers re-fetch fresh claims on each rotation.
//
// ExpiresIn in the result is cfg.AccessTTL truncated to whole seconds.
//
// An error is returned if cfg.AccessTTL or cfg.RefreshTTL is not positive
// (such tokens would already be expired when issued), or if signing either
// token fails.
func IssueTokenPair(signer Signer, uid string, customClaims map[string]any, cfg TokenConfig) (TokenPair, error) {
	// A zero-value TokenConfig would otherwise produce tokens whose exp equals
	// their iat, i.e. tokens that are unusable the moment they are issued.
	if cfg.AccessTTL <= 0 {
		return TokenPair{}, fmt.Errorf("invalid access token TTL %v: must be positive", cfg.AccessTTL)
	}
	if cfg.RefreshTTL <= 0 {
		return TokenPair{}, fmt.Errorf("invalid refresh token TTL %v: must be positive", cfg.RefreshTTL)
	}

	now := time.Now()

	accessClaims := make(jwt.MapClaims, len(customClaims)+5)
	for k, v := range customClaims {
		accessClaims[k] = v
	}
	// Registered claims are assigned after the custom ones on purpose, so that
	// caller-supplied data can never override identity or lifetime.
	accessClaims["sub"] = uid
	accessClaims["iss"] = cfg.Issuer
	accessClaims["iat"] = jwt.NewNumericDate(now)
	accessClaims["exp"] = jwt.NewNumericDate(now.Add(cfg.AccessTTL))
	accessClaims["jti"] = newJTI()

	accessToken, err := signer.Sign(accessClaims)
	if err != nil {
		return TokenPair{}, fmt.Errorf("sign access token: %w", err)
	}

	refreshClaims := jwt.MapClaims{
		"sub": uid,
		"iss": cfg.Issuer,
		"iat": jwt.NewNumericDate(now),
		"exp": jwt.NewNumericDate(now.Add(cfg.RefreshTTL)),
		"jti": newJTI(),
		"fam": newJTI(), // fresh token family for this newly issued pair
	}
	refreshToken, err := signer.Sign(refreshClaims)
	if err != nil {
		return TokenPair{}, fmt.Errorf("sign refresh token: %w", err)
	}

	return TokenPair{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ExpiresIn:    int64(cfg.AccessTTL / time.Second),
	}, nil
}
// newJTI returns a random version-4 UUID string in the canonical
// 8-4-4-4-12 lowercase hex form, used for the jti and fam claims.
//
// It panics if the secure random source fails. Ignoring that error would leave
// the buffer (partially) zeroed, which makes IDs predictable and possibly
// identical across tokens and defeats anything relying on their uniqueness.
// Since Go 1.24, crypto/rand.Read never returns an error and crashes the program
// on failure itself, so this check only matters on older toolchains.
func newJTI() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		panic(fmt.Sprintf("httpauthjwt: crypto/rand read failed: %v", err))
	}
	b[6] = (b[6] & 0x0f) | 0x40 // set version nibble to 4 (random UUID)
	b[8] = (b[8] & 0x3f) | 0x80 // set RFC 4122 variant bits (10xx)
	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}