package httpauth

import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"

	"firebase.google.com/go/v4/auth"

	"code.nochebuena.dev/go/rbac"
)

// --- mocks ---

// mockVerifier returns a canned token/error and counts how often it is invoked,
// so tests can assert that public paths skip verification entirely.
type mockVerifier struct {
	token *auth.Token
	err   error
	calls int
}

func (m *mockVerifier) VerifyIDTokenAndCheckRevoked(_ context.Context, _ string) (*auth.Token, error) {
	m.calls++
	return m.token, m.err
}

// mockEnricher returns a canned identity/error and counts how often it is invoked.
type mockEnricher struct {
	identity rbac.Identity
	err      error
	calls    int
}

func (m *mockEnricher) Enrich(_ context.Context, _ string, _ map[string]any) (rbac.Identity, error) {
	m.calls++
	return m.identity, m.err
}

// mockProvider returns a canned permission mask/error and counts how often it is invoked.
type mockProvider struct {
	mask  rbac.PermissionMask
	err   error
	calls int
}

func (m *mockProvider) ResolveMask(_ context.Context, _, _ string) (rbac.PermissionMask, error) {
	m.calls++
	return m.mask, m.err
}

// testRead is permission bit 0, used in authz tests.
const testRead rbac.Permission = 0

// chain wraps h with the middleware mw.
func chain(mw func(http.Handler) http.Handler, h http.HandlerFunc) http.Handler {
	return mw(h)
}

// injectUID bypasses AuthMiddleware for EnrichmentMiddleware tests.
func injectUID(uid string, claims map[string]any, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := setTokenData(r.Context(), uid, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// reachedHandler returns a terminal handler that writes 200 and sets *reached.
// Rejection tests use it to prove the middleware did not fall through to next
// after writing an error status (the recorder keeps only the first status).
func reachedHandler(reached *bool) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		*reached = true
		w.WriteHeader(http.StatusOK)
	}
}

// requestWithIdentity builds a GET / request whose context carries an rbac
// identity for uid, bypassing Auth/Enrichment for AuthzMiddleware tests.
func requestWithIdentity(uid string) *http.Request {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	return req.WithContext(rbac.SetInContext(req.Context(), rbac.NewIdentity(uid, "", "")))
}

// --- AuthMiddleware ---

func TestAuthMiddleware_ValidToken(t *testing.T) {
	mv := &mockVerifier{token: &auth.Token{UID: "uid123", Claims: map[string]any{"name": "Alice"}}}
	var capturedUID string
	h := chain(AuthMiddleware(mv, nil), func(w http.ResponseWriter, r *http.Request) {
		// Presence is checked through the captured value below.
		capturedUID, _ = getUID(r.Context())
		w.WriteHeader(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer valid-token")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if capturedUID != "uid123" {
		t.Errorf("want uid123, got %q", capturedUID)
	}
	if mv.calls != 1 {
		t.Errorf("want verifier called once, got %d", mv.calls)
	}
}

func TestAuthMiddleware_InvalidToken(t *testing.T) {
	mv := &mockVerifier{err: errors.New("token invalid")}
	var reached bool
	h := chain(AuthMiddleware(mv, nil), reachedHandler(&reached))

	req := httptest.NewRequest(http.MethodGet, "/api", nil)
	req.Header.Set("Authorization", "Bearer bad-token")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	if reached {
		t.Error("next handler must not run for an invalid token")
	}
}

func TestAuthMiddleware_MissingHeader(t *testing.T) {
	mv := &mockVerifier{}
	var reached bool
	h := chain(AuthMiddleware(mv, nil), reachedHandler(&reached))

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil))

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	if reached {
		t.Error("next handler must not run without an Authorization header")
	}
}

func TestAuthMiddleware_PublicPath(t *testing.T) {
	mv := &mockVerifier{err: errors.New("should not be called")}
	var reached bool
	h := chain(AuthMiddleware(mv, []string{"/health"}), reachedHandler(&reached))

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/health", nil))

	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("next handler should run for a public path")
	}
	if mv.calls != 0 {
		t.Errorf("verifier must not be called for a public path, got %d calls", mv.calls)
	}
}

func TestAuthMiddleware_PublicPathWildcard(t *testing.T) {
	mv := &mockVerifier{err: errors.New("should not be called")}
	var reached bool
	h := chain(AuthMiddleware(mv, []string{"/public/*"}), reachedHandler(&reached))

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/public/resource", nil))

	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("next handler should run for a wildcard public path")
	}
	if mv.calls != 0 {
		t.Errorf("verifier must not be called for a public path, got %d calls", mv.calls)
	}
}

// --- EnrichmentMiddleware ---

func TestEnrichmentMiddleware_Success(t *testing.T) {
	me := &mockEnricher{identity: rbac.NewIdentity("uid123", "Alice", "alice@example.com")}
	var capturedIdentity rbac.Identity
	inner := EnrichmentMiddleware(me)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Presence is checked through the captured UID below.
		capturedIdentity, _ = rbac.FromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))
	h := injectUID("uid123", nil, inner)

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if capturedIdentity.UID != "uid123" {
		t.Errorf("want uid123, got %q", capturedIdentity.UID)
	}
}

func TestEnrichmentMiddleware_NoUID(t *testing.T) {
	me := &mockEnricher{}
	var reached bool
	h := chain(EnrichmentMiddleware(me), reachedHandler(&reached))

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	if reached {
		t.Error("next handler must not run without a UID in context")
	}
	if me.calls != 0 {
		t.Errorf("enricher must not be called without a UID, got %d calls", me.calls)
	}
}

func TestEnrichmentMiddleware_EnricherError(t *testing.T) {
	me := &mockEnricher{err: errors.New("db error")}
	var reached bool
	inner := EnrichmentMiddleware(me)(reachedHandler(&reached))
	h := injectUID("uid123", nil, inner)

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if rec.Code != http.StatusInternalServerError {
		t.Errorf("want 500, got %d", rec.Code)
	}
	if reached {
		t.Error("next handler must not run when enrichment fails")
	}
}

func TestEnrichmentMiddleware_WithTenant(t *testing.T) {
	me := &mockEnricher{identity: rbac.NewIdentity("uid123", "", "")}
	var capturedIdentity rbac.Identity
	inner := EnrichmentMiddleware(me, WithTenantHeader("X-Tenant-ID"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Presence is checked through the captured TenantID below.
		capturedIdentity, _ = rbac.FromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))
	h := injectUID("uid123", nil, inner)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("X-Tenant-ID", "tenant-abc")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if capturedIdentity.TenantID != "tenant-abc" {
		t.Errorf("want tenant-abc, got %q", capturedIdentity.TenantID)
	}
}

func TestEnrichmentMiddleware_NoTenantHeader(t *testing.T) {
	me := &mockEnricher{identity: rbac.NewIdentity("uid123", "", "")}
	var (
		capturedIdentity rbac.Identity
		reached          bool
	)
	inner := EnrichmentMiddleware(me, WithTenantHeader("X-Tenant-ID"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		// Presence is irrelevant here; only the TenantID is inspected.
		capturedIdentity, _ = rbac.FromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))
	h := injectUID("uid123", nil, inner)

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	// Without these checks an empty TenantID would also result from the
	// handler never running, making the assertion below vacuous.
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Fatal("next handler was not called")
	}
	if capturedIdentity.TenantID != "" {
		t.Errorf("want empty TenantID, got %q", capturedIdentity.TenantID)
	}
}

// --- AuthzMiddleware ---

func TestAuthzMiddleware_Allowed(t *testing.T) {
	mp := &mockProvider{mask: rbac.PermissionMask(0).Grant(testRead)}
	var reached bool
	inner := AuthzMiddleware(mp, "orders", testRead)(reachedHandler(&reached))

	rec := httptest.NewRecorder()
	inner.ServeHTTP(rec, requestWithIdentity("uid123"))

	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("next handler should run when the permission is granted")
	}
}

func TestAuthzMiddleware_Denied(t *testing.T) {
	mp := &mockProvider{mask: rbac.PermissionMask(0)} // no permissions granted
	var reached bool
	inner := AuthzMiddleware(mp, "orders", testRead)(reachedHandler(&reached))

	rec := httptest.NewRecorder()
	inner.ServeHTTP(rec, requestWithIdentity("uid123"))

	if rec.Code != http.StatusForbidden {
		t.Errorf("want 403, got %d", rec.Code)
	}
	if reached {
		t.Error("next handler must not run when the permission is missing")
	}
}

func TestAuthzMiddleware_NoIdentity(t *testing.T) {
	mp := &mockProvider{}
	var reached bool
	h := chain(AuthzMiddleware(mp, "orders", testRead), reachedHandler(&reached))

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("want 401, got %d", rec.Code)
	}
	if reached {
		t.Error("next handler must not run without an identity in context")
	}
	if mp.calls != 0 {
		t.Errorf("provider must not be called without an identity, got %d calls", mp.calls)
	}
}

func TestAuthzMiddleware_ProviderError(t *testing.T) {
	mp := &mockProvider{err: errors.New("db error")}
	var reached bool
	inner := AuthzMiddleware(mp, "orders", testRead)(reachedHandler(&reached))

	rec := httptest.NewRecorder()
	inner.ServeHTTP(rec, requestWithIdentity("uid123"))

	// This test pins that a provider failure is answered with 403.
	if rec.Code != http.StatusForbidden {
		t.Errorf("want 403, got %d", rec.Code)
	}
	if reached {
		t.Error("next handler must not run when permission resolution fails")
	}
}