feat(auth): initial implementation — authmw and rbac (v1.0.0)
Introduces code.nochebuena.dev/einherjar/auth — the provider-agnostic HTTP
authentication and authorization layer of the Einherjar framework. Absorbs
two micro-lib packages (httpauth, rbac) as sub-packages, replacing the
Identity-only context model with a SecurityBag-native design and adding a
composable enrichment chain.
authmw:
- BagEnricher function type — enriches the request-scoped SecurityBag after
the base Identity is built; registered via WithBagEnricher; multiple
enrichers run in order, each receiving the bag from the previous
- IdentityEnricher interface — application-layer contract for loading user
data from uid+claims
- EnrichmentMiddleware — builds SecurityBag from uid+claims, runs enricher
chain, stores via security.SetBagInContext; 401 on missing uid, 500 on
enricher error; routes all errors through httputil.Error
- AuthzMiddleware — per-route permission gate; 401 on missing identity,
403 on provider error (fail-closed) or insufficient permissions
- EnrichOpt type + WithTenantHeader (reads TenantID from header, implemented
as a BagEnricher) + WithBagEnricher (registers custom enrichers for
hardware IDs, grant codes, or any bag attribute)
- SetTokenData / GetClaims — integration contract for auth-jwt / auth-firebase
rbac:
- NewClaimsPermissionProvider — reads flat JWT claim bitmasks from context;
wildcard "*" fallback; handles int64/float64/json.Number; zero DB calls
- NewCachedPermissionProvider — TTL cache wrapping any PermissionProvider;
default key "rbac:{uid}:{resource}" or "rbac:{tenantID}:{uid}:{resource}";
TenantID sourced from SecurityBag automatically; accepts ...CachedOpt
- CachedOpt type + WithCacheKey — overrides the key function for extra
dimensions (hardware IDs, grant codes read from bag attributes)
- NewChainPermissionProvider — tries providers in order; first non-zero wins;
errors short-circuit; typical pattern: claims → cached DB fallback
- Cache interface — pluggable backend satisfied by cache-valkey via duck typing
Compliance test (package auth_test) enforces CT-6 (≤1 exported TypeSpec/file),
compile-time interface satisfaction, and behavioural coverage across the full
middleware and provider surface: enrichment success/failure, tenant header,
custom BagEnricher, bag-in-context, authz allowed/denied/error, claims
hit/wildcard/missing/float64, cached hit/miss/error/tenant-key/custom-key,
chain first-non-zero/fallthrough/error.
Depends on contracts v1.0.0, core v1.0.0, web v1.0.0.
- identifiable.go: package-level Module variable (observability.Identifiable) for version
identification — auth is middleware-only; not registered with the launcher
This commit is contained in:
592
compliance_test.go
Normal file
592
compliance_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.nochebuena.dev/einherjar/auth/authmw"
|
||||
"code.nochebuena.dev/einherjar/auth/rbac"
|
||||
"code.nochebuena.dev/einherjar/contracts/security"
|
||||
"code.nochebuena.dev/einherjar/core/logz"
|
||||
)
|
||||
|
||||
// ── Compile-time interface satisfaction ──────────────────────────────────────
|
||||
|
||||
var _ authmw.IdentityEnricher = (*mockEnricher)(nil)
|
||||
var _ security.PermissionProvider = rbac.NewClaimsPermissionProvider("x", authmw.GetClaims)
|
||||
var _ security.PermissionProvider = rbac.NewCachedPermissionProvider(rbac.NewClaimsPermissionProvider("x", authmw.GetClaims), &mockCache{}, time.Minute)
|
||||
var _ security.PermissionProvider = rbac.NewChainPermissionProvider()
|
||||
var _ rbac.Cache = (*mockCache)(nil)
|
||||
|
||||
// ── Structural: at most one exported TypeSpec per file ────────────────────────
|
||||
|
||||
func TestAtMostOneExportedTypePerFile(t *testing.T) {
|
||||
fset := token.NewFileSet()
|
||||
err := filepath.WalkDir(".", func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() && (d.Name() == ".git" || d.Name() == "vendor") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if !strings.HasSuffix(path, ".go") {
|
||||
return nil
|
||||
}
|
||||
if strings.HasSuffix(path, "_test.go") {
|
||||
return nil
|
||||
}
|
||||
if filepath.Base(path) == "doc.go" {
|
||||
return nil
|
||||
}
|
||||
f, parseErr := parser.ParseFile(fset, path, nil, 0)
|
||||
if parseErr != nil {
|
||||
t.Errorf("%s: parse error: %v", path, parseErr)
|
||||
return nil
|
||||
}
|
||||
if count := countExportedTypes(f); count > 1 {
|
||||
t.Errorf("%s: has %d exported type declarations; want at most 1", path, count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("walk error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func countExportedTypes(f *ast.File) int {
|
||||
count := 0
|
||||
for _, decl := range f.Decls {
|
||||
gd, ok := decl.(*ast.GenDecl)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, spec := range gd.Specs {
|
||||
ts, ok := spec.(*ast.TypeSpec)
|
||||
if ok && ts.Name.IsExported() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// ── authmw: SetTokenData ──────────────────────────────────────────────────────
|
||||
|
||||
func TestSetTokenData(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
claims := map[string]any{"perms": map[string]any{"orders": int64(7)}}
|
||||
ctx = authmw.SetTokenData(ctx, "uid-1", claims)
|
||||
|
||||
got := authmw.GetClaims(ctx)
|
||||
if got == nil {
|
||||
t.Fatal("GetClaims returned nil after SetTokenData")
|
||||
}
|
||||
if _, ok := got["perms"]; !ok {
|
||||
t.Error("claims key 'perms' missing")
|
||||
}
|
||||
}
|
||||
|
||||
// ── authmw: EnrichmentMiddleware ──────────────────────────────────────────────
|
||||
|
||||
func TestEnrichmentMiddlewareSuccess(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}
|
||||
|
||||
h := authmw.EnrichmentMiddleware(logger, enricher)(okHandler())
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status: got %d, want 200", w.Code)
|
||||
}
|
||||
if !enricher.called {
|
||||
t.Error("enricher was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichmentMiddlewareMissingUID(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
enricher := &mockEnricher{}
|
||||
|
||||
h := authmw.EnrichmentMiddleware(logger, enricher)(okHandler())
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status: got %d, want 401", w.Code)
|
||||
}
|
||||
if enricher.called {
|
||||
t.Error("enricher should not be called when uid is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichmentMiddlewareEnricherError(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
enricher := &mockEnricher{err: errors.New("db down")}
|
||||
|
||||
h := authmw.EnrichmentMiddleware(logger, enricher)(okHandler())
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status: got %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichmentMiddlewareTenantHeader(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}
|
||||
|
||||
h := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithTenantHeader("X-Tenant-ID"))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id, ok := security.FromContext(r.Context())
|
||||
if !ok {
|
||||
t.Error("identity not in context")
|
||||
}
|
||||
if id.TenantID != "tenant-abc" {
|
||||
t.Errorf("TenantID: got %q, want %q", id.TenantID, "tenant-abc")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("X-Tenant-ID", "tenant-abc")
|
||||
r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status: got %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnrichmentMiddlewareBagEnricher verifies that WithBagEnricher attaches a
|
||||
// custom attribute to the SecurityBag stored in context.
|
||||
func TestEnrichmentMiddlewareBagEnricher(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}
|
||||
|
||||
hwEnricher := authmw.BagEnricher(func(bag security.SecurityBag, r *http.Request) security.SecurityBag {
|
||||
return bag.With("hardware_id", r.Header.Get("X-Hardware-ID"))
|
||||
})
|
||||
|
||||
h := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithBagEnricher(hwEnricher))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bag, ok := security.BagFromContext(r.Context())
|
||||
if !ok {
|
||||
t.Error("bag not in context")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
v, ok := bag.Get("hardware_id")
|
||||
if !ok {
|
||||
t.Error("hardware_id not in bag")
|
||||
} else if v != "hw-abc" {
|
||||
t.Errorf("hardware_id: got %v, want hw-abc", v)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("X-Hardware-ID", "hw-abc")
|
||||
r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status: got %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnrichmentMiddlewareBagInContext verifies that security.BagFromContext works
|
||||
// after enrichment, and that the bag carries the enriched Identity.
|
||||
func TestEnrichmentMiddlewareBagInContext(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
id := security.NewIdentity("uid-1", "Alice", "alice@example.com")
|
||||
enricher := &mockEnricher{identity: id}
|
||||
|
||||
h := authmw.EnrichmentMiddleware(logger, enricher)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bag, ok := security.BagFromContext(r.Context())
|
||||
if !ok {
|
||||
t.Error("BagFromContext: want ok=true, got false")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if bag.Identity() != id {
|
||||
t.Errorf("bag identity: got %+v, want %+v", bag.Identity(), id)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status: got %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── authmw: AuthzMiddleware ───────────────────────────────────────────────────
|
||||
|
||||
func TestAuthzMiddlewareAllowed(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
const ReadOrders = security.Permission(0)
|
||||
provider := &mockProvider{mask: security.PermissionMask(0).Grant(ReadOrders)}
|
||||
|
||||
h := authmw.AuthzMiddleware(logger, provider, "orders", ReadOrders)(okHandler())
|
||||
|
||||
id := security.NewIdentity("uid-1", "Alice", "alice@example.com")
|
||||
ctx := security.SetInContext(context.Background(), id)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status: got %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzMiddlewareMissingIdentity(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
provider := &mockProvider{}
|
||||
|
||||
h := authmw.AuthzMiddleware(logger, provider, "orders", 0)(okHandler())
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status: got %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzMiddlewarePermissionDenied(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
provider := &mockProvider{mask: 0} // no permissions
|
||||
|
||||
h := authmw.AuthzMiddleware(logger, provider, "orders", security.Permission(0))(okHandler())
|
||||
|
||||
id := security.NewIdentity("uid-1", "Alice", "alice@example.com")
|
||||
ctx := security.SetInContext(context.Background(), id)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status: got %d, want 403", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzMiddlewareProviderError(t *testing.T) {
|
||||
logger := logz.New(logz.Config{})
|
||||
provider := &mockProvider{err: errors.New("db error")}
|
||||
|
||||
h := authmw.AuthzMiddleware(logger, provider, "orders", security.Permission(0))(okHandler())
|
||||
|
||||
id := security.NewIdentity("uid-1", "Alice", "alice@example.com")
|
||||
ctx := security.SetInContext(context.Background(), id)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status: got %d, want 403 (fail-closed on provider error)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ── rbac: ClaimsPermissionProvider ───────────────────────────────────────────
|
||||
|
||||
func TestClaimsProviderHit(t *testing.T) {
|
||||
p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)
|
||||
claims := map[string]any{
|
||||
"perms": map[string]any{"orders": int64(7)},
|
||||
}
|
||||
ctx := authmw.SetTokenData(context.Background(), "uid-1", claims)
|
||||
|
||||
mask, err := p.ResolveMask(ctx, "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 7 {
|
||||
t.Errorf("mask: got %d, want 7", mask)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimsProviderWildcard(t *testing.T) {
|
||||
p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)
|
||||
claims := map[string]any{
|
||||
"perms": map[string]any{"*": int64(15)},
|
||||
}
|
||||
ctx := authmw.SetTokenData(context.Background(), "uid-1", claims)
|
||||
|
||||
mask, err := p.ResolveMask(ctx, "uid-1", "products")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 15 {
|
||||
t.Errorf("mask: got %d, want 15 (wildcard)", mask)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimsProviderMissing(t *testing.T) {
|
||||
p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)
|
||||
ctx := context.Background() // no claims set
|
||||
|
||||
mask, err := p.ResolveMask(ctx, "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 0 {
|
||||
t.Errorf("mask: got %d, want 0", mask)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimsProviderFloat64(t *testing.T) {
|
||||
p := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)
|
||||
claims := map[string]any{
|
||||
"perms": map[string]any{"orders": float64(3)},
|
||||
}
|
||||
ctx := authmw.SetTokenData(context.Background(), "uid-1", claims)
|
||||
|
||||
mask, err := p.ResolveMask(ctx, "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 3 {
|
||||
t.Errorf("mask: got %d, want 3 (float64 cast)", mask)
|
||||
}
|
||||
}
|
||||
|
||||
// ── rbac: CachedPermissionProvider ───────────────────────────────────────────
|
||||
|
||||
func TestCachedProviderHit(t *testing.T) {
|
||||
inner := &mockProvider{mask: 7}
|
||||
cache := &mockCache{data: map[string]int64{"rbac:uid-1:orders": 99}}
|
||||
p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute)
|
||||
|
||||
mask, err := p.ResolveMask(context.Background(), "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 99 {
|
||||
t.Errorf("mask: got %d, want 99 (from cache)", mask)
|
||||
}
|
||||
if inner.calls != 0 {
|
||||
t.Error("inner provider should not be called on cache hit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProviderMiss(t *testing.T) {
|
||||
inner := &mockProvider{mask: 7}
|
||||
cache := &mockCache{data: map[string]int64{}}
|
||||
p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute)
|
||||
|
||||
mask, err := p.ResolveMask(context.Background(), "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 7 {
|
||||
t.Errorf("mask: got %d, want 7 (from inner)", mask)
|
||||
}
|
||||
if inner.calls != 1 {
|
||||
t.Error("inner provider should be called on cache miss")
|
||||
}
|
||||
if cache.data["rbac:uid-1:orders"] != 7 {
|
||||
t.Error("result should be stored in cache after miss")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProviderCacheError(t *testing.T) {
|
||||
inner := &mockProvider{mask: 5}
|
||||
cache := &mockCache{getErr: errors.New("valkey down")}
|
||||
p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute)
|
||||
|
||||
mask, err := p.ResolveMask(context.Background(), "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 5 {
|
||||
t.Errorf("mask: got %d, want 5 (fallthrough on cache error)", mask)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProviderTenantKey(t *testing.T) {
|
||||
inner := &mockProvider{mask: 3}
|
||||
cache := &mockCache{data: map[string]int64{}}
|
||||
p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute)
|
||||
|
||||
id := security.NewIdentity("uid-1", "Alice", "alice@example.com").WithTenant("tenant-abc")
|
||||
ctx := security.SetInContext(context.Background(), id)
|
||||
|
||||
_, err := p.ResolveMask(ctx, "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
const wantKey = "rbac:tenant-abc:uid-1:orders"
|
||||
if _, ok := cache.data[wantKey]; !ok {
|
||||
t.Errorf("expected cache key %q; got keys: %v", wantKey, cache.keys())
|
||||
}
|
||||
if _, ok := cache.data["rbac:uid-1:orders"]; ok {
|
||||
t.Error("tenant-scoped key must not omit tenantID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWithCacheKey verifies that a custom CachedOpt key function is used instead
|
||||
// of the default, enabling bag attributes (e.g. hardware ID) to be part of the key.
|
||||
func TestWithCacheKey(t *testing.T) {
|
||||
inner := &mockProvider{mask: 7}
|
||||
cache := &mockCache{data: map[string]int64{}}
|
||||
p := rbac.NewCachedPermissionProvider(inner, cache, time.Minute,
|
||||
rbac.WithCacheKey(func(bag security.SecurityBag, uid, resource string) string {
|
||||
hwID, _ := bag.Get("hardware_id")
|
||||
return fmt.Sprintf("custom:%s:%v:%s", uid, hwID, resource)
|
||||
}),
|
||||
)
|
||||
|
||||
bag := security.NewSecurityBag(security.NewIdentity("uid-1", "Alice", "a@b.com")).
|
||||
With("hardware_id", "hw-xyz")
|
||||
ctx := security.SetBagInContext(context.Background(), bag)
|
||||
|
||||
_, err := p.ResolveMask(ctx, "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
const wantKey = "custom:uid-1:hw-xyz:orders"
|
||||
if _, ok := cache.data[wantKey]; !ok {
|
||||
t.Errorf("expected custom cache key %q; got keys: %v", wantKey, cache.keys())
|
||||
}
|
||||
}
|
||||
|
||||
// ── rbac: ChainPermissionProvider ────────────────────────────────────────────
|
||||
|
||||
func TestChainProviderFirstNonZero(t *testing.T) {
|
||||
first := &mockProvider{mask: 7}
|
||||
second := &mockProvider{mask: 99}
|
||||
p := rbac.NewChainPermissionProvider(first, second)
|
||||
|
||||
mask, err := p.ResolveMask(context.Background(), "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 7 {
|
||||
t.Errorf("mask: got %d, want 7 (first non-zero)", mask)
|
||||
}
|
||||
if second.calls != 0 {
|
||||
t.Error("second provider should not be called when first returns non-zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainProviderFallthrough(t *testing.T) {
|
||||
first := &mockProvider{mask: 0}
|
||||
second := &mockProvider{mask: 5}
|
||||
p := rbac.NewChainPermissionProvider(first, second)
|
||||
|
||||
mask, err := p.ResolveMask(context.Background(), "uid-1", "orders")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mask != 5 {
|
||||
t.Errorf("mask: got %d, want 5 (fallthrough to second)", mask)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainProviderError(t *testing.T) {
|
||||
first := &mockProvider{err: errors.New("provider down")}
|
||||
second := &mockProvider{mask: 99}
|
||||
p := rbac.NewChainPermissionProvider(first, second)
|
||||
|
||||
_, err := p.ResolveMask(context.Background(), "uid-1", "orders")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if second.calls != 0 {
|
||||
t.Error("second provider should not be called after error in first")
|
||||
}
|
||||
}
|
||||
|
||||
// ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func okHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
type mockEnricher struct {
|
||||
identity security.Identity
|
||||
err error
|
||||
called bool
|
||||
}
|
||||
|
||||
func (m *mockEnricher) Enrich(_ context.Context, _ string, _ map[string]any) (security.Identity, error) {
|
||||
m.called = true
|
||||
return m.identity, m.err
|
||||
}
|
||||
|
||||
type mockProvider struct {
|
||||
mask security.PermissionMask
|
||||
err error
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *mockProvider) ResolveMask(_ context.Context, _, _ string) (security.PermissionMask, error) {
|
||||
m.calls++
|
||||
return m.mask, m.err
|
||||
}
|
||||
|
||||
type mockCache struct {
|
||||
data map[string]int64
|
||||
getErr error
|
||||
}
|
||||
|
||||
func (m *mockCache) Get(_ context.Context, key string) (int64, bool, error) {
|
||||
if m.getErr != nil {
|
||||
return 0, false, m.getErr
|
||||
}
|
||||
v, ok := m.data[key]
|
||||
return v, ok, nil
|
||||
}
|
||||
|
||||
func (m *mockCache) Set(_ context.Context, key string, value int64, _ time.Duration) error {
|
||||
if m.data == nil {
|
||||
m.data = map[string]int64{}
|
||||
}
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCache) keys() []string {
|
||||
keys := make([]string, 0, len(m.data))
|
||||
for k := range m.data {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
Reference in New Issue
Block a user