package httpauth import ( "context" "errors" "net/http" "net/http/httptest" "testing" "time" "code.nochebuena.dev/go/rbac" ) // --- mocks --- type mockEnricher struct { identity rbac.Identity err error } func (m *mockEnricher) Enrich(_ context.Context, _ string, _ map[string]any) (rbac.Identity, error) { return m.identity, m.err } type mockProvider struct { mask rbac.PermissionMask err error } func (m *mockProvider) ResolveMask(_ context.Context, _, _ string) (rbac.PermissionMask, error) { return m.mask, m.err } type mockCache struct { val int64 exists bool getErr error setErr error } func (m *mockCache) Get(_ context.Context, _ string) (int64, bool, error) { return m.val, m.exists, m.getErr } func (m *mockCache) Set(_ context.Context, _ string, val int64, _ time.Duration) error { m.val = val return m.setErr } // testPerm is permission bit 1, used in authz tests. const testPerm rbac.Permission = 1 // injectTokenData bypasses an upstream AuthMiddleware for testing downstream middleware. func injectTokenData(uid string, claims map[string]any, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := SetTokenData(r.Context(), uid, claims) next.ServeHTTP(w, r.WithContext(ctx)) }) } // --- EnrichmentMiddleware --- func TestEnrichmentMiddleware_Success(t *testing.T) { me := &mockEnricher{identity: rbac.NewIdentity("uid1", "Alice", "alice@example.com")} var got rbac.Identity inner := EnrichmentMiddleware(me)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got, _ = rbac.FromContext(r.Context()) w.WriteHeader(http.StatusOK) })) h := injectTokenData("uid1", nil, inner) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } if got.UID != "uid1" { t.Errorf("want uid1, got %q", got.UID) } } func TestEnrichmentMiddleware_NoUID(t *testing.T) { me := &mockEnricher{} h := EnrichmentMiddleware(me)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusUnauthorized { t.Errorf("want 401, got %d", rec.Code) } } func TestEnrichmentMiddleware_EnricherError(t *testing.T) { me := &mockEnricher{err: errors.New("db error")} inner := EnrichmentMiddleware(me)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) h := injectTokenData("uid1", nil, inner) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusInternalServerError { t.Errorf("want 500, got %d", rec.Code) } } func TestEnrichmentMiddleware_WithTenantHeader(t *testing.T) { me := &mockEnricher{identity: rbac.NewIdentity("uid1", "", "")} var got rbac.Identity inner := EnrichmentMiddleware(me, WithTenantHeader("X-Tenant-ID"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got, _ = rbac.FromContext(r.Context()) w.WriteHeader(http.StatusOK) })) h := injectTokenData("uid1", nil, inner) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("X-Tenant-ID", "tenant-abc") rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if got.TenantID != "tenant-abc" { t.Errorf("want tenant-abc, got %q", got.TenantID) } } func TestEnrichmentMiddleware_TenantHeaderAbsent(t *testing.T) { me := &mockEnricher{identity: rbac.NewIdentity("uid1", "", "")} var got rbac.Identity inner := EnrichmentMiddleware(me, WithTenantHeader("X-Tenant-ID"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got, _ = rbac.FromContext(r.Context()) w.WriteHeader(http.StatusOK) })) h := injectTokenData("uid1", nil, inner) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if got.TenantID != "" { t.Errorf("want empty TenantID, got %q", got.TenantID) } } // --- AuthzMiddleware --- func TestAuthzMiddleware_Allowed(t *testing.T) { mp := &mockProvider{mask: rbac.PermissionMask(0).Grant(testPerm)} inner := AuthzMiddleware(mp, "resource", testPerm)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) ctx := rbac.SetInContext(req.Context(), rbac.NewIdentity("uid1", "", "")) rec := httptest.NewRecorder() inner.ServeHTTP(rec, req.WithContext(ctx)) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } } func TestAuthzMiddleware_Denied(t *testing.T) { mp := &mockProvider{mask: rbac.PermissionMask(0)} inner := AuthzMiddleware(mp, "resource", testPerm)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) ctx := rbac.SetInContext(req.Context(), rbac.NewIdentity("uid1", "", "")) rec := httptest.NewRecorder() inner.ServeHTTP(rec, req.WithContext(ctx)) if rec.Code != http.StatusForbidden { t.Errorf("want 403, got %d", rec.Code) } } func TestAuthzMiddleware_NoIdentity(t *testing.T) { mp := &mockProvider{} h := AuthzMiddleware(mp, "resource", testPerm)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusUnauthorized { t.Errorf("want 401, got %d", rec.Code) } } func TestAuthzMiddleware_ProviderError(t *testing.T) { mp := &mockProvider{err: errors.New("db error")} inner := AuthzMiddleware(mp, "resource", testPerm)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) ctx := rbac.SetInContext(req.Context(), rbac.NewIdentity("uid1", "", "")) rec := httptest.NewRecorder() inner.ServeHTTP(rec, req.WithContext(ctx)) if rec.Code != http.StatusForbidden { t.Errorf("want 403, got %d", rec.Code) } } // --- ClaimsPermissionProvider --- func TestClaimsPermissionProvider_Float64(t *testing.T) { p := NewClaimsPermissionProvider("permisos") ctx := SetTokenData(context.Background(), "uid1", map[string]any{ "permisos": map[string]any{"usuarios": float64(515)}, }) mask, err := p.ResolveMask(ctx, "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 515 { t.Errorf("want 515, got %d", mask) } } func TestClaimsPermissionProvider_Int64(t *testing.T) { p := NewClaimsPermissionProvider("permisos") ctx := SetTokenData(context.Background(), "uid1", map[string]any{ "permisos": map[string]any{"roles": int64(6)}, }) mask, err := p.ResolveMask(ctx, "uid1", "roles") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 6 { t.Errorf("want 6, got %d", mask) } } func TestClaimsPermissionProvider_NoClaims(t *testing.T) { p := NewClaimsPermissionProvider("permisos") mask, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 0 { t.Errorf("want 0, got %d", mask) } } func TestClaimsPermissionProvider_ResourceAbsent(t *testing.T) { p := NewClaimsPermissionProvider("permisos") ctx := SetTokenData(context.Background(), "uid1", map[string]any{ "permisos": map[string]any{}, }) mask, err := p.ResolveMask(ctx, "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 0 { t.Errorf("want 0, got %d", mask) } } // --- CachedPermissionProvider --- func TestCachedPermissionProvider_Hit(t *testing.T) { inner := &mockProvider{mask: 999} cache := &mockCache{val: 515, exists: true} p := NewCachedPermissionProvider(inner, cache, time.Minute) mask, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 515 { t.Errorf("want 515 from cache, got %d", mask) } } func TestCachedPermissionProvider_Miss(t *testing.T) { inner := &mockProvider{mask: 515} cache := &mockCache{exists: false} p := NewCachedPermissionProvider(inner, cache, time.Minute) mask, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 515 { t.Errorf("want 515 from inner, got %d", mask) } if cache.val != 515 { t.Errorf("expected cache populated with 515, got %d", cache.val) } } func TestCachedPermissionProvider_CacheErrorFallsThrough(t *testing.T) { inner := &mockProvider{mask: 515} cache := &mockCache{getErr: errors.New("valkey unavailable")} p := NewCachedPermissionProvider(inner, cache, time.Minute) mask, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 515 { t.Errorf("want 515 from inner on cache error, got %d", mask) } } func TestCachedPermissionProvider_InnerError(t *testing.T) { inner := &mockProvider{err: errors.New("db error")} cache := &mockCache{exists: false} p := NewCachedPermissionProvider(inner, cache, time.Minute) _, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err == nil { t.Error("expected error from inner, got nil") } } // --- ChainPermissionProvider --- func TestChainPermissionProvider_FirstNonZero(t *testing.T) { first := &mockProvider{mask: 515} second := &mockProvider{mask: 999} p := NewChainPermissionProvider(first, second) mask, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 515 { t.Errorf("want 515 from first provider, got %d", mask) } } func TestChainPermissionProvider_Fallthrough(t *testing.T) { first := &mockProvider{mask: 0} second := &mockProvider{mask: 42} p := NewChainPermissionProvider(first, second) mask, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 42 { t.Errorf("want 42 from second provider, got %d", mask) } } func TestChainPermissionProvider_AllZero(t *testing.T) { p := NewChainPermissionProvider(&mockProvider{mask: 0}, &mockProvider{mask: 0}) mask, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 0 { t.Errorf("want 0 when all providers return 0, got %d", mask) } } func TestChainPermissionProvider_ErrorPropagates(t *testing.T) { first := &mockProvider{err: errors.New("db error")} second := &mockProvider{mask: 42} p := NewChainPermissionProvider(first, second) _, err := p.ResolveMask(context.Background(), "uid1", "usuarios") if err == nil { t.Error("expected error from first provider, got nil") } }