package auth_test

import (
	"context"
	"errors"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"code.nochebuena.dev/einherjar/auth/authmw"
	"code.nochebuena.dev/einherjar/auth/rbac"
	"code.nochebuena.dev/einherjar/contracts/security"
	"code.nochebuena.dev/einherjar/core/logz"
)

// ── Compile-time interface satisfaction ──────────────────────────────────────

// Each assignment fails to compile if the value on the right stops satisfying
// the interface on the left. Because these are package-level initializers,
// the constructor calls also run once when the test binary initializes.
var (
	_ authmw.IdentityEnricher     = (*mockEnricher)(nil)
	_ security.PermissionProvider = rbac.NewClaimsPermissionProvider("x", authmw.GetClaims)
	_ security.PermissionProvider = rbac.NewCachedPermissionProvider(rbac.NewClaimsPermissionProvider("x", authmw.GetClaims), &mockCache{}, time.Minute)
	_ security.PermissionProvider = rbac.NewChainPermissionProvider()
	_ rbac.Cache                  = (*mockCache)(nil)
)

// ── Structural: at most one exported TypeSpec per file ────────────────────────

// TestAtMostOneExportedTypePerFile walks the directory tree rooted at the
// test's working directory (the package directory under `go test`). It
// reports every non-test .go file, other than doc.go, that declares more
// than one exported type. A file that fails to parse is reported as an
// error and the walk continues.
//
// Directories and files the go tool itself ignores are skipped: vendor,
// testdata, and names starting with "." or "_". This keeps fixtures and
// tooling metadata out of the rule.
func TestAtMostOneExportedTypePerFile(t *testing.T) {
	fset := token.NewFileSet()
	err := filepath.WalkDir(".", func(path string, d os.DirEntry, err error) error {
		if err != nil {
			return err
		}
		name := d.Name()
		if d.IsDir() {
			// The walk root is reported with the name "."; it must not be
			// mistaken for a hidden directory.
			if path != "." && (name == "vendor" || name == "testdata" ||
				strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_")) {
				return filepath.SkipDir
			}
			// Directories are never parsed, even if their name ends in ".go".
			return nil
		}
		if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") {
			return nil
		}
		if name == "doc.go" || strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
			return nil
		}
		f, parseErr := parser.ParseFile(fset, path, nil, 0)
		if parseErr != nil {
			t.Errorf("%s: parse error: %v", path, parseErr)
			return nil
		}
		if count := countExportedTypes(f); count > 1 {
			t.Errorf("%s: has %d exported type declarations; want at most 1", path, count)
		}
		return nil
	})
	if err != nil {
		t.Fatalf("walk error: %v", err)
	}
}

// countExportedTypes returns the number of exported type names declared at
// the top level of f. Only f.Decls is inspected, so types declared inside
// function bodies are not counted.
func countExportedTypes(f *ast.File) int {
	count := 0
	for _, decl := range f.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			if ts, ok := spec.(*ast.TypeSpec); ok && ts.Name.IsExported() {
				count++
			}
		}
	}
	return count
}

// ── authmw: SetTokenData ──────────────────────────────────────────────────────
────────────────────────────────────────────────────── func TestSetTokenData(t *testing.T) { ctx := context.Background() claims := map[string]any{"perms": map[string]any{"orders": int64(7)}} ctx = authmw.SetTokenData(ctx, "uid-1", claims) got := authmw.GetClaims(ctx) if got == nil { t.Fatal("GetClaims returned nil after SetTokenData") } if _, ok := got["perms"]; !ok { t.Error("claims key 'perms' missing") } } // ── authmw: EnrichmentMiddleware ────────────────────────────────────────────── func TestEnrichmentMiddlewareSuccess(t *testing.T) { logger := logz.New(logz.Config{}) enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")} h := authmw.EnrichmentMiddleware(logger, enricher)(okHandler()) r := httptest.NewRequest(http.MethodGet, "/", nil) r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil)) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("status: got %d, want 200", w.Code) } if !enricher.called { t.Error("enricher was not called") } } func TestEnrichmentMiddlewareMissingUID(t *testing.T) { logger := logz.New(logz.Config{}) enricher := &mockEnricher{} h := authmw.EnrichmentMiddleware(logger, enricher)(okHandler()) r := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusUnauthorized { t.Errorf("status: got %d, want 401", w.Code) } if enricher.called { t.Error("enricher should not be called when uid is missing") } } func TestEnrichmentMiddlewareEnricherError(t *testing.T) { logger := logz.New(logz.Config{}) enricher := &mockEnricher{err: errors.New("db down")} h := authmw.EnrichmentMiddleware(logger, enricher)(okHandler()) r := httptest.NewRequest(http.MethodGet, "/", nil) r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil)) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusInternalServerError { t.Errorf("status: got %d, want 500", w.Code) } } func 
TestEnrichmentMiddlewareTenantHeader(t *testing.T) { logger := logz.New(logz.Config{}) enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")} h := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithTenantHeader("X-Tenant-ID"))( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id, ok := security.FromContext(r.Context()) if !ok { t.Error("identity not in context") } if id.TenantID != "tenant-abc" { t.Errorf("TenantID: got %q, want %q", id.TenantID, "tenant-abc") } w.WriteHeader(http.StatusOK) }), ) r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Tenant-ID", "tenant-abc") r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil)) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("status: got %d, want 200", w.Code) } } // TestEnrichmentMiddlewareBagEnricher verifies that WithBagEnricher attaches a // custom attribute to the SecurityBag stored in context. func TestEnrichmentMiddlewareBagEnricher(t *testing.T) { logger := logz.New(logz.Config{}) enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")} hwEnricher := authmw.BagEnricher(func(bag security.SecurityBag, r *http.Request) security.SecurityBag { return bag.With("hardware_id", r.Header.Get("X-Hardware-ID")) }) h := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithBagEnricher(hwEnricher))( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bag, ok := security.BagFromContext(r.Context()) if !ok { t.Error("bag not in context") w.WriteHeader(http.StatusInternalServerError) return } v, ok := bag.Get("hardware_id") if !ok { t.Error("hardware_id not in bag") } else if v != "hw-abc" { t.Errorf("hardware_id: got %v, want hw-abc", v) } w.WriteHeader(http.StatusOK) }), ) r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Hardware-ID", "hw-abc") r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil)) w := 
httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("status: got %d, want 200", w.Code) } } // TestEnrichmentMiddlewareBagInContext verifies that security.BagFromContext works // after enrichment, and that the bag carries the enriched Identity. func TestEnrichmentMiddlewareBagInContext(t *testing.T) { logger := logz.New(logz.Config{}) id := security.NewIdentity("uid-1", "Alice", "alice@example.com") enricher := &mockEnricher{identity: id} h := authmw.EnrichmentMiddleware(logger, enricher)( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bag, ok := security.BagFromContext(r.Context()) if !ok { t.Error("BagFromContext: want ok=true, got false") w.WriteHeader(http.StatusInternalServerError) return } if bag.Identity() != id { t.Errorf("bag identity: got %+v, want %+v", bag.Identity(), id) } w.WriteHeader(http.StatusOK) }), ) r := httptest.NewRequest(http.MethodGet, "/", nil) r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil)) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("status: got %d, want 200", w.Code) } } // ── authmw: AuthzMiddleware ─────────────────────────────────────────────────── func TestAuthzMiddlewareAllowed(t *testing.T) { logger := logz.New(logz.Config{}) const ReadOrders = security.Permission(0) provider := &mockProvider{mask: security.PermissionMask(0).Grant(ReadOrders)} h := authmw.AuthzMiddleware(logger, provider, "orders", ReadOrders)(okHandler()) id := security.NewIdentity("uid-1", "Alice", "alice@example.com") ctx := security.SetInContext(context.Background(), id) r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("status: got %d, want 200", w.Code) } } func TestAuthzMiddlewareMissingIdentity(t *testing.T) { logger := logz.New(logz.Config{}) provider := &mockProvider{} h := authmw.AuthzMiddleware(logger, provider, "orders", 0)(okHandler()) r := 
httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusUnauthorized { t.Errorf("status: got %d, want 401", w.Code) } } func TestAuthzMiddlewarePermissionDenied(t *testing.T) { logger := logz.New(logz.Config{}) provider := &mockProvider{mask: 0} // no permissions h := authmw.AuthzMiddleware(logger, provider, "orders", security.Permission(0))(okHandler()) id := security.NewIdentity("uid-1", "Alice", "alice@example.com") ctx := security.SetInContext(context.Background(), id) r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusForbidden { t.Errorf("status: got %d, want 403", w.Code) } } func TestAuthzMiddlewareProviderError(t *testing.T) { logger := logz.New(logz.Config{}) provider := &mockProvider{err: errors.New("db error")} h := authmw.AuthzMiddleware(logger, provider, "orders", security.Permission(0))(okHandler()) id := security.NewIdentity("uid-1", "Alice", "alice@example.com") ctx := security.SetInContext(context.Background(), id) r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != http.StatusForbidden { t.Errorf("status: got %d, want 403 (fail-closed on provider error)", w.Code) } } // ── rbac: ClaimsPermissionProvider ─────────────────────────────────────────── func TestClaimsProviderHit(t *testing.T) { p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims) claims := map[string]any{ "perms": map[string]any{"orders": int64(7)}, } ctx := authmw.SetTokenData(context.Background(), "uid-1", claims) mask, err := p.ResolveMask(ctx, "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 7 { t.Errorf("mask: got %d, want 7", mask) } } func TestClaimsProviderWildcard(t *testing.T) { p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims) claims := map[string]any{ "perms": map[string]any{"*": int64(15)}, } 
ctx := authmw.SetTokenData(context.Background(), "uid-1", claims) mask, err := p.ResolveMask(ctx, "uid-1", "products") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 15 { t.Errorf("mask: got %d, want 15 (wildcard)", mask) } } func TestClaimsProviderMissing(t *testing.T) { p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims) ctx := context.Background() // no claims set mask, err := p.ResolveMask(ctx, "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 0 { t.Errorf("mask: got %d, want 0", mask) } } func TestClaimsProviderFloat64(t *testing.T) { p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims) claims := map[string]any{ "perms": map[string]any{"orders": float64(3)}, } ctx := authmw.SetTokenData(context.Background(), "uid-1", claims) mask, err := p.ResolveMask(ctx, "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 3 { t.Errorf("mask: got %d, want 3 (float64 cast)", mask) } } // ── rbac: CachedPermissionProvider ─────────────────────────────────────────── func TestCachedProviderHit(t *testing.T) { inner := &mockProvider{mask: 7} cache := &mockCache{data: map[string]int64{"rbac:uid-1:orders": 99}} p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute) mask, err := p.ResolveMask(context.Background(), "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 99 { t.Errorf("mask: got %d, want 99 (from cache)", mask) } if inner.calls != 0 { t.Error("inner provider should not be called on cache hit") } } func TestCachedProviderMiss(t *testing.T) { inner := &mockProvider{mask: 7} cache := &mockCache{data: map[string]int64{}} p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute) mask, err := p.ResolveMask(context.Background(), "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 7 { t.Errorf("mask: got %d, want 7 (from inner)", mask) } if inner.calls != 1 { t.Error("inner provider 
should be called on cache miss") } if cache.data["rbac:uid-1:orders"] != 7 { t.Error("result should be stored in cache after miss") } } func TestCachedProviderCacheError(t *testing.T) { inner := &mockProvider{mask: 5} cache := &mockCache{getErr: errors.New("valkey down")} p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute) mask, err := p.ResolveMask(context.Background(), "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 5 { t.Errorf("mask: got %d, want 5 (fallthrough on cache error)", mask) } } func TestCachedProviderTenantKey(t *testing.T) { inner := &mockProvider{mask: 3} cache := &mockCache{data: map[string]int64{}} p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute) id := security.NewIdentity("uid-1", "Alice", "alice@example.com").WithTenant("tenant-abc") ctx := security.SetInContext(context.Background(), id) _, err := p.ResolveMask(ctx, "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } const wantKey = "rbac:tenant-abc:uid-1:orders" if _, ok := cache.data[wantKey]; !ok { t.Errorf("expected cache key %q; got keys: %v", wantKey, cache.keys()) } if _, ok := cache.data["rbac:uid-1:orders"]; ok { t.Error("tenant-scoped key must not omit tenantID") } } // TestWithCacheKey verifies that a custom CachedOpt key function is used instead // of the default, enabling bag attributes (e.g. hardware ID) to be part of the key. func TestWithCacheKey(t *testing.T) { inner := &mockProvider{mask: 7} cache := &mockCache{data: map[string]int64{}} p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute, rbac.WithCacheKey(func(bag security.SecurityBag, uid, resource string) string { hwID, _ := bag.Get("hardware_id") return fmt.Sprintf("custom:%s:%v:%s", uid, hwID, resource) }), ) bag := security.NewSecurityBag(security.NewIdentity("uid-1", "Alice", "a@b.com")). 
With("hardware_id", "hw-xyz") ctx := security.SetBagInContext(context.Background(), bag) _, err := p.ResolveMask(ctx, "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } const wantKey = "custom:uid-1:hw-xyz:orders" if _, ok := cache.data[wantKey]; !ok { t.Errorf("expected custom cache key %q; got keys: %v", wantKey, cache.keys()) } } // ── rbac: ChainPermissionProvider ──────────────────────────────────────────── func TestChainProviderFirstNonZero(t *testing.T) { first := &mockProvider{mask: 7} second := &mockProvider{mask: 99} p := rbac.NewChainPermissionProvider(first, second) mask, err := p.ResolveMask(context.Background(), "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 7 { t.Errorf("mask: got %d, want 7 (first non-zero)", mask) } if second.calls != 0 { t.Error("second provider should not be called when first returns non-zero") } } func TestChainProviderFallthrough(t *testing.T) { first := &mockProvider{mask: 0} second := &mockProvider{mask: 5} p := rbac.NewChainPermissionProvider(first, second) mask, err := p.ResolveMask(context.Background(), "uid-1", "orders") if err != nil { t.Fatalf("unexpected error: %v", err) } if mask != 5 { t.Errorf("mask: got %d, want 5 (fallthrough to second)", mask) } } func TestChainProviderError(t *testing.T) { first := &mockProvider{err: errors.New("provider down")} second := &mockProvider{mask: 99} p := rbac.NewChainPermissionProvider(first, second) _, err := p.ResolveMask(context.Background(), "uid-1", "orders") if err == nil { t.Fatal("expected error, got nil") } if second.calls != 0 { t.Error("second provider should not be called after error in first") } } // ── helpers ─────────────────────────────────────────────────────────────────── func okHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) } type mockEnricher struct { identity security.Identity err error called bool } func (m 
*mockEnricher) Enrich(_ context.Context, _ string, _ map[string]any) (security.Identity, error) { m.called = true return m.identity, m.err } type mockProvider struct { mask security.PermissionMask err error calls int } func (m *mockProvider) ResolveMask(_ context.Context, _, _ string) (security.PermissionMask, error) { m.calls++ return m.mask, m.err } type mockCache struct { data map[string]int64 getErr error } func (m *mockCache) Get(_ context.Context, key string) (int64, bool, error) { if m.getErr != nil { return 0, false, m.getErr } v, ok := m.data[key] return v, ok, nil } func (m *mockCache) Set(_ context.Context, key string, value int64, _ time.Duration) error { if m.data == nil { m.data = map[string]int64{} } m.data[key] = value return nil } func (m *mockCache) keys() []string { keys := make([]string, 0, len(m.data)) for k := range m.data { keys = append(keys, k) } return keys }