package jwtauth import ( "context" "crypto/rand" "crypto/rsa" "errors" "net/http" "net/http/httptest" "testing" "time" "github.com/golang-jwt/jwt/v5" ) var ( testSecret = []byte("test-secret-key-at-least-32-bytes!") testHMAC = NewHMACSigner(testSecret) testRSAKey = mustGenerateRSA() testRSA = NewRSASigner(testRSAKey) ) func mustGenerateRSA() *rsa.PrivateKey { k, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { panic(err) } return k } var testCfg = TokenConfig{ AccessTTL: time.Minute, RefreshTTL: 7 * 24 * time.Hour, Issuer: "test-issuer", } // --- Signer --- func TestHMACSigner_SignAndVerify(t *testing.T) { claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))} tok, err := testHMAC.Sign(claims) if err != nil { t.Fatalf("Sign: %v", err) } parsed, err := testHMAC.Verify(tok) if err != nil { t.Fatalf("Verify: %v", err) } mc, _ := parsed.Claims.(jwt.MapClaims) if mc["sub"] != "uid1" { t.Errorf("want sub=uid1, got %v", mc["sub"]) } } func TestHMACSigner_TamperedToken(t *testing.T) { claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))} tok, _ := testHMAC.Sign(claims) _, err := testHMAC.Verify(tok + "tampered") if err == nil { t.Error("expected error for tampered token") } } func TestHMACSigner_WrongSecret(t *testing.T) { other := NewHMACSigner([]byte("completely-different-secret-key!")) claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))} tok, _ := testHMAC.Sign(claims) _, err := other.Verify(tok) if err == nil { t.Error("expected error for wrong secret") } } func TestHMACSigner_AlgMismatch(t *testing.T) { claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))} tok, _ := testRSA.Sign(claims) _, err := testHMAC.Verify(tok) if err == nil { t.Error("expected error for algorithm mismatch (RSA token verified with HMAC)") } } func TestRSASigner_SignAndVerify(t *testing.T) { claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))} tok, err := testRSA.Sign(claims) if err != nil { t.Fatalf("Sign: %v", err) } parsed, err := testRSA.Verify(tok) if err != nil { t.Fatalf("Verify: %v", err) } mc, _ := parsed.Claims.(jwt.MapClaims) if mc["sub"] != "uid1" { t.Errorf("want sub=uid1, got %v", mc["sub"]) } } func TestRSAPublicKeyVerifier_VerifiesTokenFromSigner(t *testing.T) { verifier := NewRSAPublicKeyVerifier(&testRSAKey.PublicKey) claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))} tok, _ := testRSA.Sign(claims) parsed, err := verifier.Verify(tok) if err != nil { t.Fatalf("Verify: %v", err) } mc, _ := parsed.Claims.(jwt.MapClaims) if mc["sub"] != "uid1" { t.Errorf("want sub=uid1, got %v", mc["sub"]) } } func TestRSAPublicKeyVerifier_RejectsHMACToken(t *testing.T) { verifier := NewRSAPublicKeyVerifier(&testRSAKey.PublicKey) claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))} tok, _ := testHMAC.Sign(claims) _, err := verifier.Verify(tok) if err == nil { t.Error("expected error: HMAC token verified with RSA public key") } } // --- IssueTokenPair --- func TestIssueTokenPair_StandardClaims(t *testing.T) { pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg) if err != nil { t.Fatalf("IssueTokenPair: %v", err) } if pair.AccessToken == "" || pair.RefreshToken == "" { t.Error("expected non-empty tokens") } if pair.ExpiresIn != int64(testCfg.AccessTTL.Seconds()) { t.Errorf("want ExpiresIn=%d, got %d", int64(testCfg.AccessTTL.Seconds()), pair.ExpiresIn) } tok, err := testHMAC.Verify(pair.AccessToken) if err != nil { t.Fatalf("verify access token: %v", err) } mc, _ := tok.Claims.(jwt.MapClaims) if mc["sub"] != "uid1" { t.Errorf("want sub=uid1, got %v", mc["sub"]) } if mc["iss"] != testCfg.Issuer { t.Errorf("want iss=%s, got %v", testCfg.Issuer, mc["iss"]) } if mc["jti"] == "" { t.Error("expected non-empty jti in access token") } } func TestIssueTokenPair_CustomClaims(t *testing.T) { custom := map[string]any{ "permisos": map[string]any{"usuarios": float64(515)}, } pair, err := IssueTokenPair(testHMAC, "uid1", custom, testCfg) if err != nil { t.Fatalf("IssueTokenPair: %v", err) } tok, _ := testHMAC.Verify(pair.AccessToken) mc, _ := tok.Claims.(jwt.MapClaims) permisos, ok := mc["permisos"].(map[string]any) if !ok { t.Fatalf("permisos claim missing or wrong type: %T", mc["permisos"]) } if permisos["usuarios"] != float64(515) { t.Errorf("want usuarios=515, got %v", permisos["usuarios"]) } } func TestIssueTokenPair_UniqueJTIs(t *testing.T) { p1, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) p2, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) if p1.AccessToken == p2.AccessToken { t.Error("expected unique access tokens across calls") } if p1.RefreshToken == p2.RefreshToken { t.Error("expected unique refresh tokens across calls") } } func TestIssueTokenPair_RefreshHasFam(t *testing.T) { pair, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) tok, _ := testHMAC.Verify(pair.RefreshToken) mc, _ := tok.Claims.(jwt.MapClaims) if mc["fam"] == "" { t.Error("expected fam claim in refresh token") } } // --- RefreshTokenPair --- type mockBlacklist struct { revoked map[string]bool err error } func newMockBlacklist() *mockBlacklist { return &mockBlacklist{revoked: make(map[string]bool)} } func (m *mockBlacklist) IsRevoked(_ context.Context, jti string) (bool, error) { if m.err != nil { return false, m.err } return m.revoked[jti], nil } func (m *mockBlacklist) Revoke(_ context.Context, jti string, _ time.Duration) error { if m.err != nil { return m.err } m.revoked[jti] = true return nil } func TestRefreshTokenPair_Success(t *testing.T) { bl := newMockBlacklist() pair, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) newPair, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil) if err != nil { t.Fatalf("RefreshTokenPair: %v", err) } if newPair.AccessToken == "" || newPair.RefreshToken == "" { t.Error("expected non-empty new token pair") } if newPair.RefreshToken == pair.RefreshToken { t.Error("new refresh token must differ from old") } } func TestRefreshTokenPair_OldTokenRevoked(t *testing.T) { bl := newMockBlacklist() pair, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) if _, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil); err != nil { t.Fatalf("first refresh: %v", err) } _, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil) if !errors.Is(err, ErrTokenRevoked) { t.Errorf("want ErrTokenRevoked, got %v", err) } } func TestRefreshTokenPair_InvalidToken(t *testing.T) { bl := newMockBlacklist() _, err := RefreshTokenPair(context.Background(), testHMAC, "not.a.token", bl, testCfg, nil) if err == nil { t.Error("expected error for invalid token string") } } func TestRefreshTokenPair_BlacklistCheckError(t *testing.T) { bl := &mockBlacklist{revoked: make(map[string]bool), err: errors.New("valkey unavailable")} pair, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) _, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil) if err == nil { t.Error("expected error when blacklist is unavailable") } } func TestRefreshTokenPair_CustomClaimsInNewToken(t *testing.T) { bl := newMockBlacklist() pair, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) freshClaims := map[string]any{"permisos": map[string]any{"usuarios": float64(7)}} newPair, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, freshClaims) if err != nil { t.Fatalf("RefreshTokenPair: %v", err) } tok, _ := testHMAC.Verify(newPair.AccessToken) mc, _ := tok.Claims.(jwt.MapClaims) permisos, ok := mc["permisos"].(map[string]any) if !ok { t.Fatalf("permisos missing from new access token") } if permisos["usuarios"] != float64(7) { t.Errorf("want 7, got %v", permisos["usuarios"]) } } // --- AuthMiddleware --- func TestAuthMiddleware_ValidToken(t *testing.T) { pair, _ := IssueTokenPair(testHMAC, "uid1", nil, testCfg) reached := false h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reached = true w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/api", nil) req.Header.Set("Authorization", "Bearer "+pair.AccessToken) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } if !reached { t.Error("inner handler was not called") } } func TestAuthMiddleware_InvalidToken(t *testing.T) { h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/api", nil) req.Header.Set("Authorization", "Bearer invalid.token.here") rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("want 401, got %d", rec.Code) } } func TestAuthMiddleware_ExpiredToken(t *testing.T) { expiredCfg := TokenConfig{AccessTTL: -time.Minute, RefreshTTL: time.Hour, Issuer: "test"} pair, _ := IssueTokenPair(testHMAC, "uid1", nil, expiredCfg) h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/api", nil) req.Header.Set("Authorization", "Bearer "+pair.AccessToken) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("want 401, got %d", rec.Code) } } func TestAuthMiddleware_MissingHeader(t *testing.T) { h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil)) if rec.Code != http.StatusUnauthorized { t.Errorf("want 401, got %d", rec.Code) } } func TestAuthMiddleware_PublicPath(t *testing.T) { h := AuthMiddleware(testHMAC, []string{"/health"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/health", nil)) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } } func TestAuthMiddleware_PublicPathWildcard(t *testing.T) { h := AuthMiddleware(testHMAC, []string{"/public/*"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/public/resource", nil)) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } } func TestAuthMiddleware_RSAPublicKeyVerifier(t *testing.T) { verifier := NewRSAPublicKeyVerifier(&testRSAKey.PublicKey) pair, _ := IssueTokenPair(testRSA, "uid1", nil, testCfg) reached := false h := AuthMiddleware(verifier, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reached = true w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/api", nil) req.Header.Set("Authorization", "Bearer "+pair.AccessToken) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } if !reached { t.Error("inner handler was not called") } }