Files
auth-jwt/compliance_test.go

598 lines
18 KiB
Go
Raw Normal View History

feat(auth-jwt): initial implementation — JWT lifecycle and AuthMiddleware (v1.0.0). Introduces code.nochebuena.dev/einherjar/auth-jwt — JWT authentication middleware and token lifecycle management for the Einherjar framework. Absorbs httpauth-jwt from micro-lib with three changes: logger parameter on AuthMiddleware, httputil.Error for consistent 401 responses, and added ECDSA support. Signers and Verifiers (one file per implementation — CT-6 compliant): - Verifier interface — Verify(tokenString string) (*jwt.Token, error) - Signer interface — extends Verifier; adds Sign(claims jwt.Claims) (string, error) - signer_hmac.go — NewHMACSigner(secret) → HS256; jwt.WithJSONNumber() on Verify - signer_rsa.go — NewRSASigner(key) + NewRSASignerFromPEM(pem) → RS256 - verifier_rsa.go — NewRSAPublicKeyVerifier + NewRSAPublicKeyVerifierFromPEM → RS256 verify-only - signer_ec.go — NewECSigner(key) + NewECSignerFromPEM(pem) → ES256/384/512; algorithm auto-detected from key curve (P-256→ES256, P-384→ES384, P-521→ES512) - verifier_ec.go — NewECPublicKeyVerifier + NewECPublicKeyVerifierFromPEM → EC verify-only Token lifecycle: - TokenConfig struct — AccessTTL, RefreshTTL, Issuer - TokenPair struct — AccessToken, RefreshToken, ExpiresIn - IssueTokenPair — access + refresh pair; customClaims merged at top level; refresh carries only sub/iss/iat/exp/jti/fam; jwt.WithJSONNumber() preserves int64 bitmasks - Blacklist interface — IsRevoked + Revoke; satisfied by cache-valkey via duck typing - ErrTokenRevoked — errors.New sentinel; errors.Is pattern for replay-attack detection - RefreshTokenPair — verifies token, checks blacklist, revokes old JTI, issues new pair HTTP middleware: - AuthMiddleware(logger, verifier, publicPaths) — verifies Bearer tokens; calls authmw.SetTokenData on success; 401 routed through httputil.Error (Warn level); publicPaths use path.Match wildcards; accepts Verifier (not Signer) to enforce narrowest-interface principle for verify-only services.
Compliance test (package authjwt_test) enforces CT-6 (≤1 exported TypeSpec/file), compile-time interface satisfaction, and behavioural coverage: HMAC/RSA/EC sign+verify, algorithm mismatch rejection, IssueTokenPair claims/jti/fam, RefreshTokenPair success/revoked/blacklist-error/custom-claims, MaxInt64 json.Number precision, AuthMiddleware valid/invalid/expired/missing/public-path/wildcard/JSON-body/RSA-verifier/SetTokenData. Depends on auth v1.0.0, contracts v1.0.0, core v1.0.0, web v1.0.0, jwt/v5 v5.2.1. - identifiable.go: package-level Module variable (observability.Identifiable) for version identification — auth-jwt is a function library; not registered with the launcher
2026-05-29 16:13:01 +00:00
package authjwt_test
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"go/ast"
"go/parser"
"go/token"
"math"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
authjwt "code.nochebuena.dev/einherjar/auth-jwt"
"code.nochebuena.dev/einherjar/auth/authmw"
"code.nochebuena.dev/einherjar/contracts/logging"
)
// --- Package-level test keys (generated once to avoid per-test cost) ---

// Shared signing fixtures. Key generation is comparatively expensive, so every
// test reuses these values instead of minting its own keys.
var (
	// HMAC signer over a fixed shared secret.
	testSecret = []byte("test-secret-key-at-least-32-bytes!")
	testHMAC   = authjwt.NewHMACSigner(testSecret)

	// RSA signer; testRSAKey.PublicKey also backs the verify-only RSA tests.
	testRSAKey = mustGenerateRSA()
	testRSA    = authjwt.NewRSASigner(testRSAKey)

	// EC signer on curve P-256; testECKey.PublicKey backs the verify-only EC tests.
	testECKey = mustGenerateEC(elliptic.P256())
	testEC    = authjwt.NewECSigner(testECKey)
)
// mustGenerateRSA returns a freshly generated 2048-bit RSA private key and
// panics if generation fails. It is used to initialise package-level fixtures,
// where no *testing.T is available to report an error.
func mustGenerateRSA() *rsa.PrivateKey {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err == nil {
		return key
	}
	panic(err)
}
// mustGenerateEC returns a freshly generated ECDSA private key on the given
// curve and panics if generation fails. Like mustGenerateRSA it is meant for
// fixture setup, where there is no *testing.T to report an error to.
func mustGenerateEC(curve elliptic.Curve) *ecdsa.PrivateKey {
	key, err := ecdsa.GenerateKey(curve, rand.Reader)
	if err == nil {
		return key
	}
	panic(err)
}
// testCfg is the TokenConfig shared by most tests. It uses one-minute access
// tokens, one-week refresh tokens, and a fixed issuer that the claim tests
// assert against.
var testCfg = authjwt.TokenConfig{
	Issuer:     "test-issuer",
	AccessTTL:  time.Minute,
	RefreshTTL: 7 * 24 * time.Hour,
}
// --- Mock types ---

// mockSigner is a do-nothing Signer. It exists only for the compile-time
// interface satisfaction checks in this file.
type mockSigner struct{}

// Sign ignores its claims and returns an empty token with no error.
func (*mockSigner) Sign(jwt.Claims) (string, error) { return "", nil }

// Verify ignores its input and returns a nil token with no error.
func (*mockSigner) Verify(string) (*jwt.Token, error) { return nil, nil }

// mockVerifier is a do-nothing Verifier used only for the compile-time checks.
type mockVerifier struct{}

// Verify ignores its input and returns a nil token with no error.
func (*mockVerifier) Verify(string) (*jwt.Token, error) { return nil, nil }
// mockBlacklist is an in-memory Blacklist for the refresh-token tests.
// Setting err makes every method fail with that error, simulating an
// unavailable backing store. It is not safe for concurrent use.
type mockBlacklist struct {
	revoked map[string]bool // JTIs passed to Revoke
	err     error           // when non-nil, returned by every method
}

// newMockBlacklist returns an empty, healthy mockBlacklist.
func newMockBlacklist() *mockBlacklist {
	bl := new(mockBlacklist)
	bl.revoked = map[string]bool{}
	return bl
}

// IsRevoked reports whether jti was previously passed to Revoke, or returns
// the injected error.
func (b *mockBlacklist) IsRevoked(_ context.Context, jti string) (bool, error) {
	if b.err == nil {
		return b.revoked[jti], nil
	}
	return false, b.err
}

// Revoke records jti as revoked, ignoring the TTL, or returns the injected
// error without recording anything.
func (b *mockBlacklist) Revoke(_ context.Context, jti string, _ time.Duration) error {
	if b.err != nil {
		return b.err
	}
	b.revoked[jti] = true
	return nil
}
// Compile-time interface satisfaction checks: the build fails if any mock
// drifts out of sync with the authjwt interface it stands in for.
var (
	_ authjwt.Signer    = (*mockSigner)(nil)
	_ authjwt.Verifier  = (*mockVerifier)(nil)
	_ authjwt.Blacklist = (*mockBlacklist)(nil)
)
// --- nopLogger ---

// nopLogger is a logging.Logger that discards every message, keeping the
// middleware tests free of log output.
type nopLogger struct{}

func (nopLogger) Debug(string, ...any) {}

func (nopLogger) Info(string, ...any) {}

func (nopLogger) Warn(string, ...any) {}

func (nopLogger) Error(string, error, ...any) {}

// With returns another silent logger; the attributes are dropped.
func (l nopLogger) With(...any) logging.Logger { return l }

// WithContext returns another silent logger; the context is ignored.
func (l nopLogger) WithContext(context.Context) logging.Logger { return l }

var _ logging.Logger = nopLogger{}
// --- CT-6 ---

// TestAtMostOneExportedTypePerFile enforces CT-6: every non-test Go source
// file in the package directory may declare at most one exported type.
//
// Files are enumerated with filepath.Glob and parsed one at a time. This
// avoids parser.ParseDir, which is deprecated as of Go 1.25. The test fails
// when no source files are found, so running from an unexpected working
// directory cannot turn the check into a silent pass.
func TestAtMostOneExportedTypePerFile(t *testing.T) {
	files, err := filepath.Glob("*.go")
	if err != nil {
		t.Fatalf("glob: %v", err)
	}
	fset := token.NewFileSet()
	checked := 0
	for _, fname := range files {
		base := filepath.Base(fname)
		if strings.HasSuffix(base, "_test.go") {
			continue
		}
		f, err := parser.ParseFile(fset, fname, nil, parser.SkipObjectResolution)
		if err != nil {
			t.Fatalf("parse %s: %v", base, err)
		}
		// External test packages never contribute production types.
		if strings.HasSuffix(f.Name.Name, "_test") {
			continue
		}
		checked++
		count := 0
		for _, decl := range f.Decls {
			gd, ok := decl.(*ast.GenDecl)
			if !ok || gd.Tok != token.TYPE {
				continue
			}
			for _, spec := range gd.Specs {
				if ts, ok := spec.(*ast.TypeSpec); ok && ts.Name.IsExported() {
					count++
				}
			}
		}
		if count > 1 {
			t.Errorf("%s: %d exported TypeSpecs (max 1)", base, count)
		}
	}
	if checked == 0 {
		t.Fatal("no non-test .go files found in package directory; CT-6 check did nothing")
	}
}
// --- Signer / Verifier ---

// TestHMACSigner_SignAndVerify round-trips a token through the shared HMAC
// signer and checks that the subject claim survives.
func TestHMACSigner_SignAndVerify(t *testing.T) {
	exp := jwt.NewNumericDate(time.Now().Add(time.Minute))
	signed, err := testHMAC.Sign(jwt.MapClaims{"sub": "uid1", "exp": exp})
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	token, err := testHMAC.Verify(signed)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	if claims, _ := token.Claims.(jwt.MapClaims); claims["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", claims["sub"])
	}
}
// TestHMACSigner_TamperedToken checks that appending junk to a valid HMAC
// token makes verification fail. The Sign error is checked so that a signing
// failure (an empty token) cannot make the rejection pass vacuously.
func TestHMACSigner_TamperedToken(t *testing.T) {
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if _, err := testHMAC.Verify(tok + "tampered"); err == nil {
		t.Error("expected error for tampered token")
	}
}
// TestHMACSigner_WrongSecret checks that a token signed with one HMAC secret
// is rejected by a signer holding a different secret. The Sign error is
// checked so an empty token cannot make the rejection pass vacuously.
func TestHMACSigner_WrongSecret(t *testing.T) {
	other := authjwt.NewHMACSigner([]byte("completely-different-secret-key!"))
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if _, err := other.Verify(tok); err == nil {
		t.Error("expected error for wrong secret")
	}
}
// TestHMACSigner_AlgMismatch checks that the HMAC signer refuses a token
// signed with RSA (algorithm-confusion guard). The RSA Sign error is checked
// so a signing failure cannot make the rejection pass vacuously.
func TestHMACSigner_AlgMismatch(t *testing.T) {
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testRSA.Sign(claims)
	if err != nil {
		t.Fatalf("RSA Sign: %v", err)
	}
	if _, err := testHMAC.Verify(tok); err == nil {
		t.Error("expected error: RSA token verified with HMAC")
	}
}
// TestRSASigner_SignAndVerify round-trips a token through the shared RSA
// signer and checks that the subject claim survives.
func TestRSASigner_SignAndVerify(t *testing.T) {
	exp := jwt.NewNumericDate(time.Now().Add(time.Minute))
	signed, err := testRSA.Sign(jwt.MapClaims{"sub": "uid1", "exp": exp})
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	token, err := testRSA.Verify(signed)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	if claims, _ := token.Claims.(jwt.MapClaims); claims["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", claims["sub"])
	}
}
// TestRSAPublicKeyVerifier_VerifiesTokenFromSigner checks that a verify-only
// verifier built from the RSA public key accepts a token from the matching
// private-key signer. The Sign error is reported on its own so it is not
// misattributed to Verify.
func TestRSAPublicKeyVerifier_VerifiesTokenFromSigner(t *testing.T) {
	verifier := authjwt.NewRSAPublicKeyVerifier(&testRSAKey.PublicKey)
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testRSA.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if _, err := verifier.Verify(tok); err != nil {
		t.Fatalf("Verify: %v", err)
	}
}
// TestRSAPublicKeyVerifier_RejectsHMACToken checks that an RSA verify-only
// verifier refuses an HMAC-signed token (algorithm-confusion guard). The Sign
// error is checked so an empty token cannot make the rejection pass vacuously.
func TestRSAPublicKeyVerifier_RejectsHMACToken(t *testing.T) {
	verifier := authjwt.NewRSAPublicKeyVerifier(&testRSAKey.PublicKey)
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("HMAC Sign: %v", err)
	}
	if _, err := verifier.Verify(tok); err == nil {
		t.Error("expected error: HMAC token verified with RSA public key")
	}
}
// TestECSigner_SignAndVerify round-trips a token through the shared P-256 EC
// signer and checks that the subject claim survives.
func TestECSigner_SignAndVerify(t *testing.T) {
	exp := jwt.NewNumericDate(time.Now().Add(time.Minute))
	signed, err := testEC.Sign(jwt.MapClaims{"sub": "uid1", "exp": exp})
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	token, err := testEC.Verify(signed)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	if claims, _ := token.Claims.(jwt.MapClaims); claims["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", claims["sub"])
	}
}
// TestECSigner_P384 checks that a signer built from a P-384 key round-trips a
// token. It also checks that the signer auto-selects ES384 for that curve; a
// plain round-trip would still pass if the wrong algorithm were chosen.
func TestECSigner_P384(t *testing.T) {
	key := mustGenerateEC(elliptic.P384())
	signer := authjwt.NewECSigner(key)
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := signer.Sign(claims)
	if err != nil {
		t.Fatalf("Sign P-384: %v", err)
	}
	parsed, err := signer.Verify(tok)
	if err != nil {
		t.Fatalf("Verify P-384: %v", err)
	}
	if alg, _ := parsed.Header["alg"].(string); alg != "ES384" {
		t.Errorf("want alg ES384 for a P-384 key, got %v", parsed.Header["alg"])
	}
}
// TestECPublicKeyVerifier_VerifiesTokenFromSigner checks that a verify-only
// verifier built from the EC public key accepts a token from the matching
// private-key signer. The Sign error is reported on its own so it is not
// misattributed to Verify.
func TestECPublicKeyVerifier_VerifiesTokenFromSigner(t *testing.T) {
	verifier := authjwt.NewECPublicKeyVerifier(&testECKey.PublicKey)
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testEC.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if _, err := verifier.Verify(tok); err != nil {
		t.Fatalf("Verify: %v", err)
	}
}
// TestECPublicKeyVerifier_RejectsHMACToken checks that an EC verify-only
// verifier refuses an HMAC-signed token (algorithm-confusion guard). The Sign
// error is checked so an empty token cannot make the rejection pass vacuously.
func TestECPublicKeyVerifier_RejectsHMACToken(t *testing.T) {
	verifier := authjwt.NewECPublicKeyVerifier(&testECKey.PublicKey)
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("HMAC Sign: %v", err)
	}
	if _, err := verifier.Verify(tok); err == nil {
		t.Error("expected error: HMAC token verified with EC public key")
	}
}
// --- IssueTokenPair ---

// TestIssueTokenPair_StandardClaims checks ExpiresIn and the sub, iss and jti
// claims on a freshly issued access token.
func TestIssueTokenPair_StandardClaims(t *testing.T) {
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	if pair.AccessToken == "" || pair.RefreshToken == "" {
		t.Error("expected non-empty tokens")
	}
	if pair.ExpiresIn != int64(testCfg.AccessTTL.Seconds()) {
		t.Errorf("want ExpiresIn=%d, got %d", int64(testCfg.AccessTTL.Seconds()), pair.ExpiresIn)
	}
	tok, err := testHMAC.Verify(pair.AccessToken)
	if err != nil {
		t.Fatalf("verify access token: %v", err)
	}
	mc, ok := tok.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", tok.Claims)
	}
	if mc["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", mc["sub"])
	}
	if mc["iss"] != testCfg.Issuer {
		t.Errorf("want iss=%s, got %v", testCfg.Issuer, mc["iss"])
	}
	// Comparing mc["jti"] == "" would be false for a missing claim (nil), so
	// require a non-empty string explicitly.
	if jti, _ := mc["jti"].(string); jti == "" {
		t.Errorf("expected non-empty string jti in access token, got %v", mc["jti"])
	}
}
// TestIssueTokenPair_CustomClaims checks that nested custom claims reach the
// access token and that integers decode as json.Number rather than float64.
// The Verify error is checked so a verification failure reports clearly
// instead of panicking on a nil token.
func TestIssueTokenPair_CustomClaims(t *testing.T) {
	custom := map[string]any{"permisos": map[string]any{"usuarios": int64(515)}}
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", custom, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	tok, err := testHMAC.Verify(pair.AccessToken)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	mc, _ := tok.Claims.(jwt.MapClaims)
	permisos, ok := mc["permisos"].(map[string]any)
	if !ok {
		t.Fatalf("permisos claim missing or wrong type: %T", mc["permisos"])
	}
	n, ok := permisos["usuarios"].(json.Number)
	if !ok {
		t.Fatalf("want json.Number, got %T", permisos["usuarios"])
	}
	v, err := n.Int64()
	if err != nil {
		t.Fatalf("Int64(): %v", err)
	}
	if v != 515 {
		t.Errorf("want usuarios=515, got %d", v)
	}
}
// TestIssueTokenPair_UniqueJTIs issues two pairs for the same subject and
// checks that the jti claims differ, for both access and refresh tokens.
// It compares the claims themselves rather than the encoded token strings,
// because token strings can differ for reasons unrelated to the jti.
func TestIssueTokenPair_UniqueJTIs(t *testing.T) {
	p1, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("first IssueTokenPair: %v", err)
	}
	p2, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("second IssueTokenPair: %v", err)
	}
	// jtiOf verifies raw and returns its non-empty string jti claim.
	jtiOf := func(raw string) string {
		t.Helper()
		tok, err := testHMAC.Verify(raw)
		if err != nil {
			t.Fatalf("Verify: %v", err)
		}
		mc, _ := tok.Claims.(jwt.MapClaims)
		jti, _ := mc["jti"].(string)
		if jti == "" {
			t.Fatalf("token has no string jti claim: %v", mc["jti"])
		}
		return jti
	}
	if jtiOf(p1.AccessToken) == jtiOf(p2.AccessToken) {
		t.Error("expected unique access-token jti across calls")
	}
	if jtiOf(p1.RefreshToken) == jtiOf(p2.RefreshToken) {
		t.Error("expected unique refresh-token jti across calls")
	}
}
// TestIssueTokenPair_RefreshHasFam checks that the refresh token carries a
// non-empty fam (token family) claim.
func TestIssueTokenPair_RefreshHasFam(t *testing.T) {
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	tok, err := testHMAC.Verify(pair.RefreshToken)
	if err != nil {
		t.Fatalf("verify refresh token: %v", err)
	}
	mc, _ := tok.Claims.(jwt.MapClaims)
	// A plain mc["fam"] == "" check is satisfied by a missing claim (nil),
	// so test presence and emptiness explicitly.
	if fam, ok := mc["fam"]; !ok || fam == nil || fam == "" {
		t.Errorf("expected non-empty fam claim in refresh token, got %v", fam)
	}
}
// --- RefreshTokenPair ---

// TestRefreshTokenPair_Success rotates a valid refresh token and checks that a
// complete pair is returned whose refresh token differs from the old one.
func TestRefreshTokenPair_Success(t *testing.T) {
	bl := newMockBlacklist()
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	newPair, err := authjwt.RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil)
	if err != nil {
		t.Fatalf("RefreshTokenPair: %v", err)
	}
	if newPair.AccessToken == "" || newPair.RefreshToken == "" {
		t.Error("expected non-empty new token pair")
	}
	if newPair.RefreshToken == pair.RefreshToken {
		t.Error("new refresh token must differ from old")
	}
}
// TestRefreshTokenPair_OldTokenRevoked checks replay protection: once a
// refresh token has been rotated, presenting it again yields ErrTokenRevoked.
func TestRefreshTokenPair_OldTokenRevoked(t *testing.T) {
	bl := newMockBlacklist()
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	if _, err := authjwt.RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil); err != nil {
		t.Fatalf("first refresh: %v", err)
	}
	_, err = authjwt.RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil)
	if !errors.Is(err, authjwt.ErrTokenRevoked) {
		t.Errorf("want ErrTokenRevoked, got %v", err)
	}
}
func TestRefreshTokenPair_InvalidToken(t *testing.T) {
bl := newMockBlacklist()
_, err := authjwt.RefreshTokenPair(context.Background(), testHMAC, "not.a.token", bl, testCfg, nil)
if err == nil {
t.Error("expected error for invalid token string")
}
}
func TestRefreshTokenPair_BlacklistCheckError(t *testing.T) {
bl := &mockBlacklist{revoked: make(map[string]bool), err: errors.New("valkey unavailable")}
pair, _ := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
_, err := authjwt.RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil)
if err == nil {
t.Error("expected error when blacklist is unavailable")
}
}
// TestRefreshTokenPair_CustomClaimsInNewToken checks that custom claims passed
// to RefreshTokenPair appear in the new access token, with numbers decoded as
// json.Number. Every error is checked so failures report clearly instead of
// panicking on a nil token.
func TestRefreshTokenPair_CustomClaimsInNewToken(t *testing.T) {
	bl := newMockBlacklist()
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	freshClaims := map[string]any{"permisos": map[string]any{"usuarios": float64(7)}}
	newPair, err := authjwt.RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, freshClaims)
	if err != nil {
		t.Fatalf("RefreshTokenPair: %v", err)
	}
	tok, err := testHMAC.Verify(newPair.AccessToken)
	if err != nil {
		t.Fatalf("Verify new access token: %v", err)
	}
	mc, _ := tok.Claims.(jwt.MapClaims)
	permisos, ok := mc["permisos"].(map[string]any)
	if !ok {
		t.Fatalf("permisos missing from new access token")
	}
	n, ok := permisos["usuarios"].(json.Number)
	if !ok {
		t.Fatalf("want json.Number, got %T", permisos["usuarios"])
	}
	v, err := n.Int64()
	if err != nil {
		t.Fatalf("Int64(): %v", err)
	}
	if v != 7 {
		t.Errorf("want 7, got %d", v)
	}
}
// --- Precision ---

// TestVerify_JSONNumberPreservesMaxInt64 guards against float64 precision loss.
// math.MaxInt64 is not exactly representable as a float64, so the value only
// survives a sign/verify round-trip if Verify decodes numbers as json.Number.
func TestVerify_JSONNumberPreservesMaxInt64(t *testing.T) {
	const want = int64(math.MaxInt64)
	custom := map[string]any{"masks": map[string]any{"*": want}}
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", custom, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	token, err := testHMAC.Verify(pair.AccessToken)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	claims, _ := token.Claims.(jwt.MapClaims)
	masks, isMap := claims["masks"].(map[string]any)
	if !isMap {
		t.Fatalf("masks claim missing or wrong type: %T", claims["masks"])
	}
	num, isNum := masks["*"].(json.Number)
	if !isNum {
		t.Fatalf("want json.Number, got %T — jwt.WithJSONNumber() may not be set", masks["*"])
	}
	got, err := num.Int64()
	switch {
	case err != nil:
		t.Fatalf("Int64(): %v", err)
	case got != want:
		t.Errorf("want MaxInt64 (%d), got %d — precision lost in float64 round-trip", want, got)
	}
}
// --- AuthMiddleware ---

// TestAuthMiddleware_ValidToken checks that a request bearing a valid access
// token reaches the inner handler.
func TestAuthMiddleware_ValidToken(t *testing.T) {
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	reached := false
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("inner handler was not called")
	}
}
// TestAuthMiddleware_InvalidToken checks that a malformed bearer token gets a
// 401 and never reaches the inner handler.
func TestAuthMiddleware_InvalidToken(t *testing.T) {
	reached := false
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer invalid.token.here")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	// The recorder ignores a second WriteHeader, so a middleware that wrote
	// 401 but still called next would keep the 401; check the handler directly.
	if reached {
		t.Error("inner handler must not run for an invalid token")
	}
}
// TestAuthMiddleware_ExpiredToken checks that an already-expired access token
// (negative AccessTTL) gets a 401 and never reaches the inner handler. The
// issuance error is checked so an empty token cannot produce a vacuous 401.
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
	expiredCfg := authjwt.TokenConfig{AccessTTL: -time.Minute, RefreshTTL: time.Hour, Issuer: "test"}
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", nil, expiredCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	reached := false
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	if reached {
		t.Error("inner handler must not run for an expired token")
	}
}
// TestAuthMiddleware_MissingHeader checks that a request without an
// Authorization header gets a 401 and never reaches the inner handler.
func TestAuthMiddleware_MissingHeader(t *testing.T) {
	reached := false
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil))
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	if reached {
		t.Error("inner handler must not run without an Authorization header")
	}
}
// TestAuthMiddleware_PublicPath checks that an exact public path is passed
// through to the inner handler without a token. ResponseRecorder.Code
// defaults to 200, so the status alone cannot prove the handler ran; the
// reached flag does.
func TestAuthMiddleware_PublicPath(t *testing.T) {
	reached := false
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, []string{"/health"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/health", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("inner handler was not called for public path")
	}
}
// TestAuthMiddleware_PublicPathWildcard checks that a path matching a wildcard
// public pattern is passed through to the inner handler without a token.
// ResponseRecorder.Code defaults to 200, so the reached flag is what proves
// the handler ran.
func TestAuthMiddleware_PublicPathWildcard(t *testing.T) {
	reached := false
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, []string{"/public/*"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/public/resource", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("inner handler was not called for wildcard public path")
	}
}
// TestAuthMiddleware_UnauthorizedJSON checks the shape of the 401 response: a
// JSON content type and a decodable body whose "code" is "UNAUTHENTICATED".
func TestAuthMiddleware_UnauthorizedJSON(t *testing.T) {
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil))
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Errorf("want Content-Type application/json, got %q", ct)
	}
	var body map[string]any
	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
		t.Fatalf("response body is not valid JSON: %v", err)
	}
	// body["code"] is an `any`; %v prints nil or non-string values readably,
	// whereas %q would render them as %!q(...).
	if body["code"] != "UNAUTHENTICATED" {
		t.Errorf("want code UNAUTHENTICATED, got %v", body["code"])
	}
}
// TestAuthMiddleware_RSAPublicKeyVerifier checks that the middleware accepts a
// verify-only RSA verifier and admits a token issued by the matching RSA signer.
func TestAuthMiddleware_RSAPublicKeyVerifier(t *testing.T) {
	verifier := authjwt.NewRSAPublicKeyVerifier(&testRSAKey.PublicKey)
	pair, err := authjwt.IssueTokenPair(testRSA, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	reached := false
	h := authjwt.AuthMiddleware(nopLogger{}, verifier, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("inner handler was not called")
	}
}
// TestAuthMiddleware_SetsTokenData checks that, after successful
// authentication, the inner handler can read the token's claims from the
// request context via authmw.GetClaims. This covers both the subject and a
// custom top-level claim.
func TestAuthMiddleware_SetsTokenData(t *testing.T) {
	custom := map[string]any{"role": "admin"}
	pair, err := authjwt.IssueTokenPair(testHMAC, "uid1", custom, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	var gotClaims map[string]any
	h := authjwt.AuthMiddleware(nopLogger{}, testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotClaims = authmw.GetClaims(r.Context())
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("want 200, got %d", rec.Code)
	}
	if gotClaims == nil {
		t.Fatal("expected claims in context via authmw.GetClaims")
	}
	if uid, _ := gotClaims["sub"].(string); uid != "uid1" {
		t.Errorf("want uid=uid1 in claims, got %q", uid)
	}
	if role, _ := gotClaims["role"].(string); role != "admin" {
		t.Errorf("want role=admin in claims, got %v", gotClaims["role"])
	}
}