// httpauth-jwt/jwtauth_test.go

package jwtauth
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Shared signing fixtures used across the tests. Go initializes package-level
// variables in dependency order, so each signer is built after its key.
var (
	// HMAC fixtures: a fixed shared secret and the signer built from it.
	testSecret = []byte("test-secret-key-at-least-32-bytes!")
	testHMAC   = NewHMACSigner(testSecret)

	// RSA fixtures: one key generated when the test binary starts, plus the
	// signer that wraps it.
	testRSAKey = mustGenerateRSA()
	testRSA    = NewRSASigner(testRSAKey)
)
// mustGenerateRSA returns a freshly generated 2048-bit RSA private key read
// from crypto/rand. It panics on failure because it only runs during
// package-level fixture initialization, where no *testing.T is available.
func mustGenerateRSA() *rsa.PrivateKey {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err == nil {
		return key
	}
	panic(err)
}
// testCfg is the default token configuration for the tests. It sets a
// one-minute access token, a one-week refresh token and a fixed issuer
// string.
var testCfg = TokenConfig{
	Issuer:     "test-issuer",
	AccessTTL:  time.Minute,
	RefreshTTL: 7 * 24 * time.Hour, // one week
}
// --- Signer ---
// TestHMACSigner_SignAndVerify checks that a token signed by the HMAC signer
// round-trips through Verify with its claims intact.
func TestHMACSigner_SignAndVerify(t *testing.T) {
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	parsed, err := testHMAC.Verify(tok)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	// Fail with the actual type rather than letting a nil map surface as a
	// confusing "got <nil>" sub mismatch.
	mc, ok := parsed.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", parsed.Claims)
	}
	if mc["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", mc["sub"])
	}
}
// TestHMACSigner_TamperedToken checks that Verify rejects a token whose
// signature segment has been altered.
func TestHMACSigner_TamperedToken(t *testing.T) {
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	// A Sign failure must stop the test; otherwise Verify("" + "tampered")
	// fails for the wrong reason and the test passes vacuously.
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	// Appending to a compact JWS extends its last (signature) segment.
	if _, err := testHMAC.Verify(tok + "tampered"); err == nil {
		t.Error("expected error for tampered token")
	}
}
// TestHMACSigner_WrongSecret checks that a token signed with one HMAC secret
// is rejected by a signer configured with a different secret.
func TestHMACSigner_WrongSecret(t *testing.T) {
	other := NewHMACSigner([]byte("completely-different-secret-key!"))
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	// Without this check a Sign failure would yield "" and the expected
	// Verify error would come from the empty string, not the wrong secret.
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if _, err := other.Verify(tok); err == nil {
		t.Error("expected error for wrong secret")
	}
}
// TestHMACSigner_AlgMismatch checks that the HMAC signer refuses a token
// signed with the RSA signer (algorithm-confusion guard).
func TestHMACSigner_AlgMismatch(t *testing.T) {
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	// A Sign failure would otherwise produce "" and make the rejection
	// below meaningless.
	tok, err := testRSA.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if _, err := testHMAC.Verify(tok); err == nil {
		t.Error("expected error for algorithm mismatch (RSA token verified with HMAC)")
	}
}
// TestRSASigner_SignAndVerify checks that a token signed by the RSA signer
// round-trips through its own Verify with its claims intact.
func TestRSASigner_SignAndVerify(t *testing.T) {
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	tok, err := testRSA.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	parsed, err := testRSA.Verify(tok)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	// Report an unexpected claims type directly instead of a nil-map mismatch.
	mc, ok := parsed.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", parsed.Claims)
	}
	if mc["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", mc["sub"])
	}
}
// TestRSAPublicKeyVerifier_VerifiesTokenFromSigner checks that a verifier
// holding only the public half of testRSAKey accepts tokens signed with the
// private key.
func TestRSAPublicKeyVerifier_VerifiesTokenFromSigner(t *testing.T) {
	verifier := NewRSAPublicKeyVerifier(&testRSAKey.PublicKey)
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	// Surface a Sign failure as such rather than as a later Verify error.
	tok, err := testRSA.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	parsed, err := verifier.Verify(tok)
	if err != nil {
		t.Fatalf("Verify: %v", err)
	}
	mc, ok := parsed.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", parsed.Claims)
	}
	if mc["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", mc["sub"])
	}
}
// TestRSAPublicKeyVerifier_RejectsHMACToken checks that the RSA public-key
// verifier refuses an HMAC-signed token (algorithm-confusion guard).
func TestRSAPublicKeyVerifier_RejectsHMACToken(t *testing.T) {
	verifier := NewRSAPublicKeyVerifier(&testRSAKey.PublicKey)
	claims := jwt.MapClaims{"sub": "uid1", "exp": jwt.NewNumericDate(time.Now().Add(time.Minute))}
	// A Sign failure would otherwise yield "" and the test would pass for
	// the wrong reason.
	tok, err := testHMAC.Sign(claims)
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if _, err := verifier.Verify(tok); err == nil {
		t.Error("expected error: HMAC token verified with RSA public key")
	}
}
// --- IssueTokenPair ---
// TestIssueTokenPair_StandardClaims checks the basic shape of a freshly
// issued pair. Both tokens must be present and ExpiresIn must equal AccessTTL
// in whole seconds. The access token must carry sub, iss and a non-empty jti.
func TestIssueTokenPair_StandardClaims(t *testing.T) {
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	if pair.AccessToken == "" || pair.RefreshToken == "" {
		t.Error("expected non-empty tokens")
	}
	if want := int64(testCfg.AccessTTL.Seconds()); pair.ExpiresIn != want {
		t.Errorf("want ExpiresIn=%d, got %d", want, pair.ExpiresIn)
	}
	tok, err := testHMAC.Verify(pair.AccessToken)
	if err != nil {
		t.Fatalf("verify access token: %v", err)
	}
	mc, ok := tok.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", tok.Claims)
	}
	if mc["sub"] != "uid1" {
		t.Errorf("want sub=uid1, got %v", mc["sub"])
	}
	if mc["iss"] != testCfg.Issuer {
		t.Errorf("want iss=%s, got %v", testCfg.Issuer, mc["iss"])
	}
	// A missing jti reads back as a nil interface, and nil != "", so comparing
	// mc["jti"] to "" directly would let an absent claim pass. Extract the
	// string first.
	if jti, _ := mc["jti"].(string); jti == "" {
		t.Errorf("expected non-empty jti in access token, got %v", mc["jti"])
	}
}
// TestIssueTokenPair_CustomClaims checks that caller-supplied claims are
// embedded in the access token. JSON numbers decode back as float64, which is
// why the expected value is float64(515).
func TestIssueTokenPair_CustomClaims(t *testing.T) {
	custom := map[string]any{
		"permisos": map[string]any{"usuarios": float64(515)},
	}
	pair, err := IssueTokenPair(testHMAC, "uid1", custom, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	// An ignored Verify error would leave tok nil and panic on tok.Claims.
	tok, err := testHMAC.Verify(pair.AccessToken)
	if err != nil {
		t.Fatalf("verify access token: %v", err)
	}
	mc, ok := tok.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", tok.Claims)
	}
	permisos, ok := mc["permisos"].(map[string]any)
	if !ok {
		t.Fatalf("permisos claim missing or wrong type: %T", mc["permisos"])
	}
	if permisos["usuarios"] != float64(515) {
		t.Errorf("want usuarios=515, got %v", permisos["usuarios"])
	}
}
// TestIssueTokenPair_UniqueJTIs checks that two consecutive issuances for the
// same subject produce distinct tokens and distinct access-token jti values.
func TestIssueTokenPair_UniqueJTIs(t *testing.T) {
	p1, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("first IssueTokenPair: %v", err)
	}
	p2, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("second IssueTokenPair: %v", err)
	}
	if p1.AccessToken == p2.AccessToken {
		t.Error("expected unique access tokens across calls")
	}
	if p1.RefreshToken == p2.RefreshToken {
		t.Error("expected unique refresh tokens across calls")
	}

	// accessJTI verifies token and returns its jti claim ("" if absent or
	// not a string).
	accessJTI := func(token string) string {
		t.Helper()
		parsed, err := testHMAC.Verify(token)
		if err != nil {
			t.Fatalf("verify access token: %v", err)
		}
		mc, ok := parsed.Claims.(jwt.MapClaims)
		if !ok {
			t.Fatalf("want jwt.MapClaims, got %T", parsed.Claims)
		}
		id, _ := mc["jti"].(string)
		return id
	}
	// Compare the jti claims themselves: differing token strings alone could
	// come from any other varying claim.
	if j1, j2 := accessJTI(p1.AccessToken), accessJTI(p2.AccessToken); j1 == "" || j1 == j2 {
		t.Errorf("expected distinct non-empty access-token jti values, got %q and %q", j1, j2)
	}
}
// TestIssueTokenPair_RefreshHasFam checks that the refresh token carries a
// non-empty "fam" claim.
func TestIssueTokenPair_RefreshHasFam(t *testing.T) {
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	tok, err := testHMAC.Verify(pair.RefreshToken)
	if err != nil {
		t.Fatalf("verify refresh token: %v", err)
	}
	mc, ok := tok.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", tok.Claims)
	}
	// mc["fam"] == "" is false for a missing claim (nil != ""), so check
	// presence explicitly as well as emptiness.
	fam, present := mc["fam"]
	if !present || fam == nil || fam == "" {
		t.Errorf("expected fam claim in refresh token, got %v", fam)
	}
}
// --- RefreshTokenPair ---
// mockBlacklist is an in-memory stand-in for the token revocation store used
// by the RefreshTokenPair tests. When err is non-nil every method fails with
// it, which simulates an unavailable backend. It uses a plain map with no
// locking, so it is not safe for concurrent use.
type mockBlacklist struct {
	revoked map[string]bool // jti -> revoked
	err     error           // injected failure returned by every method
}

// newMockBlacklist returns a healthy blacklist with nothing revoked.
func newMockBlacklist() *mockBlacklist {
	bl := new(mockBlacklist)
	bl.revoked = map[string]bool{}
	return bl
}

// IsRevoked reports whether jti is marked revoked, or returns the injected
// error. The context is ignored.
func (m *mockBlacklist) IsRevoked(_ context.Context, jti string) (bool, error) {
	if err := m.err; err != nil {
		return false, err
	}
	revoked := m.revoked[jti]
	return revoked, nil
}

// Revoke marks jti as revoked, or returns the injected error without
// recording anything. The context and TTL are ignored, so entries never
// expire.
func (m *mockBlacklist) Revoke(_ context.Context, jti string, _ time.Duration) error {
	if err := m.err; err != nil {
		return err
	}
	m.revoked[jti] = true
	return nil
}
// TestRefreshTokenPair_Success checks that a valid, unrevoked refresh token
// is exchanged for a new non-empty pair whose refresh token differs from the
// one presented.
func TestRefreshTokenPair_Success(t *testing.T) {
	bl := newMockBlacklist()
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	newPair, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil)
	if err != nil {
		t.Fatalf("RefreshTokenPair: %v", err)
	}
	if newPair.AccessToken == "" || newPair.RefreshToken == "" {
		t.Error("expected non-empty new token pair")
	}
	if newPair.RefreshToken == pair.RefreshToken {
		t.Error("new refresh token must differ from old")
	}
}
// TestRefreshTokenPair_OldTokenRevoked checks rotation. Once a refresh token
// has been used, presenting it again must fail with ErrTokenRevoked.
func TestRefreshTokenPair_OldTokenRevoked(t *testing.T) {
	ctx := context.Background()
	bl := newMockBlacklist()
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	if _, err := RefreshTokenPair(ctx, testHMAC, pair.RefreshToken, bl, testCfg, nil); err != nil {
		t.Fatalf("first refresh: %v", err)
	}
	_, err = RefreshTokenPair(ctx, testHMAC, pair.RefreshToken, bl, testCfg, nil)
	if !errors.Is(err, ErrTokenRevoked) {
		t.Errorf("want ErrTokenRevoked, got %v", err)
	}
}
func TestRefreshTokenPair_InvalidToken(t *testing.T) {
bl := newMockBlacklist()
_, err := RefreshTokenPair(context.Background(), testHMAC, "not.a.token", bl, testCfg, nil)
if err == nil {
t.Error("expected error for invalid token string")
}
}
// TestRefreshTokenPair_BlacklistCheckError checks that RefreshTokenPair fails
// (rather than issuing tokens) when the revocation store returns an error.
func TestRefreshTokenPair_BlacklistCheckError(t *testing.T) {
	bl := &mockBlacklist{revoked: make(map[string]bool), err: errors.New("valkey unavailable")}
	// Without this check a failed issuance would leave an empty refresh
	// token, and the expected error would come from parsing "" instead of
	// from the blacklist.
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	if _, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, nil); err == nil {
		t.Error("expected error when blacklist is unavailable")
	}
}
// TestRefreshTokenPair_CustomClaimsInNewToken checks that claims passed to
// RefreshTokenPair end up in the newly issued access token.
func TestRefreshTokenPair_CustomClaimsInNewToken(t *testing.T) {
	bl := newMockBlacklist()
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	freshClaims := map[string]any{"permisos": map[string]any{"usuarios": float64(7)}}
	newPair, err := RefreshTokenPair(context.Background(), testHMAC, pair.RefreshToken, bl, testCfg, freshClaims)
	if err != nil {
		t.Fatalf("RefreshTokenPair: %v", err)
	}
	// An ignored Verify error would leave tok nil and panic on tok.Claims.
	tok, err := testHMAC.Verify(newPair.AccessToken)
	if err != nil {
		t.Fatalf("verify new access token: %v", err)
	}
	mc, ok := tok.Claims.(jwt.MapClaims)
	if !ok {
		t.Fatalf("want jwt.MapClaims, got %T", tok.Claims)
	}
	permisos, ok := mc["permisos"].(map[string]any)
	if !ok {
		t.Fatalf("permisos missing from new access token")
	}
	if permisos["usuarios"] != float64(7) {
		t.Errorf("want 7, got %v", permisos["usuarios"])
	}
}
// --- AuthMiddleware ---
// TestAuthMiddleware_ValidToken checks that a request bearing a valid access
// token is passed through to the wrapped handler.
func TestAuthMiddleware_ValidToken(t *testing.T) {
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	reached := false
	h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("inner handler was not called")
	}
}
// TestAuthMiddleware_InvalidToken checks that a malformed bearer token gets a
// 401 and never reaches the protected handler.
func TestAuthMiddleware_InvalidToken(t *testing.T) {
	reached := false
	h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer invalid.token.here")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	// ResponseRecorder ignores a second WriteHeader, so a middleware that
	// writes 401 and then still calls the handler would pass the status check.
	if reached {
		t.Error("inner handler must not run for an invalid token")
	}
}
// TestAuthMiddleware_ExpiredToken checks that an access token issued with a
// negative TTL (already expired) gets a 401 and never reaches the protected
// handler.
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
	expiredCfg := TokenConfig{AccessTTL: -time.Minute, RefreshTTL: time.Hour, Issuer: "test"}
	// If issuance fails (e.g. a negative TTL is rejected), the empty token
	// would be refused for being missing, not expired. Fail loudly instead.
	pair, err := IssueTokenPair(testHMAC, "uid1", nil, expiredCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	reached := false
	h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	// A second WriteHeader is ignored by ResponseRecorder, so check that the
	// handler did not run even though the status already reads 401.
	if reached {
		t.Error("inner handler must not run for an expired token")
	}
}
// TestAuthMiddleware_MissingHeader checks that a request without an
// Authorization header on a protected path gets a 401 and never reaches the
// handler.
func TestAuthMiddleware_MissingHeader(t *testing.T) {
	reached := false
	h := AuthMiddleware(testHMAC, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil))
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	// The status check alone cannot detect a fall-through: ResponseRecorder
	// keeps the first status written.
	if reached {
		t.Error("inner handler must not run without an Authorization header")
	}
}
// TestAuthMiddleware_PublicPath checks that an exact public path is served
// without any Authorization header.
func TestAuthMiddleware_PublicPath(t *testing.T) {
	reached := false
	h := AuthMiddleware(testHMAC, []string{"/health"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/health", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	// NewRecorder starts with Code == 200, so the status check would also pass
	// if the middleware returned without invoking the handler.
	if !reached {
		t.Error("inner handler was not called for a public path")
	}
}
// TestAuthMiddleware_PublicPathWildcard checks that a path matched by a
// trailing-"*" public pattern is served without any Authorization header.
func TestAuthMiddleware_PublicPathWildcard(t *testing.T) {
	reached := false
	h := AuthMiddleware(testHMAC, []string{"/public/*"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/public/resource", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	// NewRecorder defaults Code to 200; confirm the handler actually ran.
	if !reached {
		t.Error("inner handler was not called for a wildcard public path")
	}
}
// TestAuthMiddleware_RSAPublicKeyVerifier checks that the middleware accepts
// an RSA-signed access token when configured with only the matching public
// key.
func TestAuthMiddleware_RSAPublicKeyVerifier(t *testing.T) {
	verifier := NewRSAPublicKeyVerifier(&testRSAKey.PublicKey)
	// Surface an issuance failure directly instead of as a puzzling 401.
	pair, err := IssueTokenPair(testRSA, "uid1", nil, testCfg)
	if err != nil {
		t.Fatalf("IssueTokenPair: %v", err)
	}
	reached := false
	h := AuthMiddleware(verifier, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("inner handler was not called")
	}
}