593 lines
18 KiB
Go
593 lines
18 KiB
Go
|
|
package auth_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"go/ast"
|
||
|
|
"go/parser"
|
||
|
|
"go/token"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"code.nochebuena.dev/einherjar/auth/authmw"
|
||
|
|
"code.nochebuena.dev/einherjar/auth/rbac"
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/security"
|
||
|
|
"code.nochebuena.dev/einherjar/core/logz"
|
||
|
|
)
|
||
|
|
|
||
|
|
// ── Compile-time interface satisfaction ──────────────────────────────────────
|
||
|
|
|
||
|
|
var _ authmw.IdentityEnricher = (*mockEnricher)(nil)
|
||
|
|
var _ security.PermissionProvider = rbac.NewClaimsPermissionProvider("x", authmw.GetClaims)
|
||
|
|
var _ security.PermissionProvider = rbac.NewCachedPermissionProvider(rbac.NewClaimsPermissionProvider("x", authmw.GetClaims), &mockCache{}, time.Minute)
|
||
|
|
var _ security.PermissionProvider = rbac.NewChainPermissionProvider()
|
||
|
|
var _ rbac.Cache = (*mockCache)(nil)
|
||
|
|
|
||
|
|
// ── Structural: at most one exported TypeSpec per file ────────────────────────
|
||
|
|
|
||
|
|
// TestAtMostOneExportedTypePerFile enforces the layout rule that every
// production Go file under the test's working directory declares at most one
// exported type.
//
// Test files (*_test.go) and doc.go are exempt. The walk skips "vendor" and
// everything the go tool itself ignores: directories named "testdata", and
// files or directories whose names begin with "." or "_". Fixture sources
// kept there are not part of any package and must not trip the rule.
func TestAtMostOneExportedTypePerFile(t *testing.T) {
	fset := token.NewFileSet()

	// ignoredByGoTool reports whether the go tool would skip this name
	// (covers .git, .idea, _scratch, …).
	ignoredByGoTool := func(name string) bool {
		return strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_")
	}

	err := filepath.WalkDir(".", func(path string, d os.DirEntry, err error) error {
		if err != nil {
			return err
		}
		name := d.Name()

		if d.IsDir() {
			// The walk root is "." itself; never skip it because of its leading dot.
			if path != "." && (name == "vendor" || name == "testdata" || ignoredByGoTool(name)) {
				return filepath.SkipDir
			}
			// A directory is never parsed, even if its name ends in ".go".
			return nil
		}

		// Only production Go sources are subject to the rule.
		if !strings.HasSuffix(name, ".go") ||
			strings.HasSuffix(name, "_test.go") ||
			name == "doc.go" ||
			ignoredByGoTool(name) {
			return nil
		}

		f, parseErr := parser.ParseFile(fset, path, nil, 0)
		if parseErr != nil {
			// Report and keep walking so every offending file is listed in one run.
			t.Errorf("%s: parse error: %v", path, parseErr)
			return nil
		}
		if count := countExportedTypes(f); count > 1 {
			t.Errorf("%s: has %d exported type declarations; want at most 1", path, count)
		}
		return nil
	})
	if err != nil {
		t.Fatalf("walk error: %v", err)
	}
}
|
||
|
|
|
||
|
|
func countExportedTypes(f *ast.File) int {
|
||
|
|
count := 0
|
||
|
|
for _, decl := range f.Decls {
|
||
|
|
gd, ok := decl.(*ast.GenDecl)
|
||
|
|
if !ok {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
for _, spec := range gd.Specs {
|
||
|
|
ts, ok := spec.(*ast.TypeSpec)
|
||
|
|
if ok && ts.Name.IsExported() {
|
||
|
|
count++
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return count
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── authmw: SetTokenData ──────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
// TestSetTokenData checks that claims attached with SetTokenData can be read
// back through GetClaims on the derived context.
func TestSetTokenData(t *testing.T) {
	stored := map[string]any{
		"perms": map[string]any{"orders": int64(7)},
	}
	ctx := authmw.SetTokenData(context.Background(), "uid-1", stored)

	got := authmw.GetClaims(ctx)
	if got == nil {
		t.Fatal("GetClaims returned nil after SetTokenData")
	}
	if _, present := got["perms"]; !present {
		t.Error("claims key 'perms' missing")
	}
}
|
||
|
|
|
||
|
|
// ── authmw: EnrichmentMiddleware ──────────────────────────────────────────────
|
||
|
|
|
||
|
|
// TestEnrichmentMiddlewareSuccess checks that a request carrying a uid is
// enriched and forwarded to the next handler.
func TestEnrichmentMiddlewareSuccess(t *testing.T) {
	logger := logz.New(logz.Config{})
	alice := security.NewIdentity("uid-1", "Alice", "alice@example.com")
	enricher := &mockEnricher{identity: alice}

	mw := authmw.EnrichmentMiddleware(logger, enricher)
	handler := mw(okHandler())

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(authmw.SetTokenData(req.Context(), "uid-1", nil))
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", rec.Code)
	}
	if !enricher.called {
		t.Error("enricher was not called")
	}
}
|
||
|
|
|
||
|
|
func TestEnrichmentMiddlewareMissingUID(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{})
|
||
|
|
enricher := &mockEnricher{}
|
||
|
|
|
||
|
|
h := authmw.EnrichmentMiddleware(logger, enricher)(okHandler())
|
||
|
|
|
||
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, r)
|
||
|
|
|
||
|
|
if w.Code != http.StatusUnauthorized {
|
||
|
|
t.Errorf("status: got %d, want 401", w.Code)
|
||
|
|
}
|
||
|
|
if enricher.called {
|
||
|
|
t.Error("enricher should not be called when uid is missing")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// TestEnrichmentMiddlewareEnricherError checks that an enricher failure is
// surfaced to the client as a 500.
func TestEnrichmentMiddlewareEnricherError(t *testing.T) {
	logger := logz.New(logz.Config{})
	failing := &mockEnricher{err: errors.New("db down")}

	handler := authmw.EnrichmentMiddleware(logger, failing)(okHandler())

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	ctx := authmw.SetTokenData(req.Context(), "uid-1", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req.WithContext(ctx))

	if got := rec.Code; got != http.StatusInternalServerError {
		t.Errorf("status: got %d, want 500", got)
	}
}
|
||
|
|
|
||
|
|
// TestEnrichmentMiddlewareTenantHeader checks that, with WithTenantHeader
// configured, the value of the named request header ends up as the TenantID
// of the Identity placed in the request context.
func TestEnrichmentMiddlewareTenantHeader(t *testing.T) {
	logger := logz.New(logz.Config{})
	enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}

	h := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithTenantHeader("X-Tenant-ID"))(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			id, ok := security.FromContext(r.Context())
			if !ok {
				// Stop here instead of inspecting a missing identity: that
				// would only add a misleading second failure (or a panic if
				// Identity is pointer-like). The 500 also fails the outer
				// status check, matching the other in-handler assertions.
				t.Error("identity not in context")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			if id.TenantID != "tenant-abc" {
				t.Errorf("TenantID: got %q, want %q", id.TenantID, "tenant-abc")
			}
			w.WriteHeader(http.StatusOK)
		}),
	)

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("X-Tenant-ID", "tenant-abc")
	r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)

	if w.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", w.Code)
	}
}
|
||
|
|
|
||
|
|
// TestEnrichmentMiddlewareBagEnricher checks that a BagEnricher registered via
// WithBagEnricher can copy request data (here the X-Hardware-ID header) into
// the SecurityBag that downstream handlers read from the context.
func TestEnrichmentMiddlewareBagEnricher(t *testing.T) {
	logger := logz.New(logz.Config{})
	enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}

	tagHardware := authmw.BagEnricher(func(bag security.SecurityBag, r *http.Request) security.SecurityBag {
		return bag.With("hardware_id", r.Header.Get("X-Hardware-ID"))
	})

	inspect := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		bag, ok := security.BagFromContext(r.Context())
		if !ok {
			t.Error("bag not in context")
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		switch v, found := bag.Get("hardware_id"); {
		case !found:
			t.Error("hardware_id not in bag")
		case v != "hw-abc":
			t.Errorf("hardware_id: got %v, want hw-abc", v)
		}
		w.WriteHeader(http.StatusOK)
	})
	handler := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithBagEnricher(tagHardware))(inspect)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("X-Hardware-ID", "hw-abc")
	req = req.WithContext(authmw.SetTokenData(req.Context(), "uid-1", nil))
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", rec.Code)
	}
}
|
||
|
|
|
||
|
|
// TestEnrichmentMiddlewareBagInContext checks that after enrichment the
// request context carries a SecurityBag, and that the bag's Identity is the
// one returned by the enricher.
func TestEnrichmentMiddlewareBagInContext(t *testing.T) {
	logger := logz.New(logz.Config{})
	want := security.NewIdentity("uid-1", "Alice", "alice@example.com")
	enricher := &mockEnricher{identity: want}

	inspect := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		bag, ok := security.BagFromContext(r.Context())
		if !ok {
			t.Error("BagFromContext: want ok=true, got false")
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if bag.Identity() != want {
			t.Errorf("bag identity: got %+v, want %+v", bag.Identity(), want)
		}
		w.WriteHeader(http.StatusOK)
	})
	handler := authmw.EnrichmentMiddleware(logger, enricher)(inspect)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(authmw.SetTokenData(req.Context(), "uid-1", nil))
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", rec.Code)
	}
}
|
||
|
|
|
||
|
|
// ── authmw: AuthzMiddleware ───────────────────────────────────────────────────
|
||
|
|
|
||
|
|
// TestAuthzMiddlewareAllowed checks that an identity whose resolved mask
// grants the required permission reaches the next handler.
func TestAuthzMiddlewareAllowed(t *testing.T) {
	const readOrders = security.Permission(0)

	logger := logz.New(logz.Config{})
	granted := security.PermissionMask(0).Grant(readOrders)
	provider := &mockProvider{mask: granted}
	handler := authmw.AuthzMiddleware(logger, provider, "orders", readOrders)(okHandler())

	alice := security.NewIdentity("uid-1", "Alice", "alice@example.com")
	ctx := security.SetInContext(context.Background(), alice)
	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", rec.Code)
	}
}
|
||
|
|
|
||
|
|
// TestAuthzMiddlewareMissingIdentity checks that a request whose context has
// no Identity is rejected with 401.
func TestAuthzMiddlewareMissingIdentity(t *testing.T) {
	logger := logz.New(logz.Config{})
	handler := authmw.AuthzMiddleware(logger, &mockProvider{}, "orders", 0)(okHandler())

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("status: got %d, want 401", got)
	}
}
|
||
|
|
|
||
|
|
// TestAuthzMiddlewarePermissionDenied checks that an authenticated identity
// whose mask lacks the required permission gets 403.
func TestAuthzMiddlewarePermissionDenied(t *testing.T) {
	logger := logz.New(logz.Config{})
	noPerms := &mockProvider{} // zero mask: nothing granted
	handler := authmw.AuthzMiddleware(logger, noPerms, "orders", security.Permission(0))(okHandler())

	ctx := security.SetInContext(context.Background(),
		security.NewIdentity("uid-1", "Alice", "alice@example.com"))
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx))

	if got := rec.Code; got != http.StatusForbidden {
		t.Errorf("status: got %d, want 403", got)
	}
}
|
||
|
|
|
||
|
|
// TestAuthzMiddlewareProviderError checks that a permission-lookup failure
// denies access (403) instead of letting the request through.
func TestAuthzMiddlewareProviderError(t *testing.T) {
	logger := logz.New(logz.Config{})
	broken := &mockProvider{err: errors.New("db error")}
	handler := authmw.AuthzMiddleware(logger, broken, "orders", security.Permission(0))(okHandler())

	ctx := security.SetInContext(context.Background(),
		security.NewIdentity("uid-1", "Alice", "alice@example.com"))
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx))

	if got := rec.Code; got != http.StatusForbidden {
		t.Errorf("status: got %d, want 403 (fail-closed on provider error)", got)
	}
}
|
||
|
|
|
||
|
|
// ── rbac: ClaimsPermissionProvider ───────────────────────────────────────────
|
||
|
|
|
||
|
|
// TestClaimsProviderHit checks that a mask stored under the resource name in
// the configured claim is returned verbatim.
func TestClaimsProviderHit(t *testing.T) {
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)
	perms := map[string]any{"orders": int64(7)}
	ctx := authmw.SetTokenData(context.Background(), "uid-1", map[string]any{"perms": perms})

	got, err := provider.ResolveMask(ctx, "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 7 {
		t.Errorf("mask: got %d, want 7", got)
	}
}
|
||
|
|
|
||
|
|
// TestClaimsProviderWildcard checks that a "*" entry in the claim applies to a
// resource that has no entry of its own.
func TestClaimsProviderWildcard(t *testing.T) {
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)
	perms := map[string]any{"*": int64(15)}
	ctx := authmw.SetTokenData(context.Background(), "uid-1", map[string]any{"perms": perms})

	got, err := provider.ResolveMask(ctx, "uid-1", "products")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 15 {
		t.Errorf("mask: got %d, want 15 (wildcard)", got)
	}
}
|
||
|
|
|
||
|
|
// TestClaimsProviderMissing checks that a context without any claims resolves
// to an empty mask rather than an error.
func TestClaimsProviderMissing(t *testing.T) {
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)

	// context.Background carries no token data at all.
	got, err := provider.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 0 {
		t.Errorf("mask: got %d, want 0", got)
	}
}
|
||
|
|
|
||
|
|
// TestClaimsProviderFloat64 checks that a float64 mask value (what JSON
// decoding yields for numbers) is converted to the integer mask.
func TestClaimsProviderFloat64(t *testing.T) {
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)
	perms := map[string]any{"orders": float64(3)}
	ctx := authmw.SetTokenData(context.Background(), "uid-1", map[string]any{"perms": perms})

	got, err := provider.ResolveMask(ctx, "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 3 {
		t.Errorf("mask: got %d, want 3 (float64 cast)", got)
	}
}
|
||
|
|
|
||
|
|
// ── rbac: CachedPermissionProvider ───────────────────────────────────────────
|
||
|
|
|
||
|
|
// TestCachedProviderHit checks that a cached mask wins over the wrapped
// provider and that the wrapped provider is not consulted.
func TestCachedProviderHit(t *testing.T) {
	inner := &mockProvider{mask: 7}
	store := &mockCache{data: map[string]int64{"rbac:uid-1:orders": 99}}
	provider := rbac.NewCachedPermissionProvider(inner, store, time.Minute)

	got, err := provider.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 99 {
		t.Errorf("mask: got %d, want 99 (from cache)", got)
	}
	if inner.calls > 0 {
		t.Error("inner provider should not be called on cache hit")
	}
}
|
||
|
|
|
||
|
|
// TestCachedProviderMiss checks that on a cache miss the wrapped provider is
// asked exactly once and its answer is written back under the default key.
func TestCachedProviderMiss(t *testing.T) {
	inner := &mockProvider{mask: 7}
	store := &mockCache{data: map[string]int64{}}
	provider := rbac.NewCachedPermissionProvider(inner, store, time.Minute)

	got, err := provider.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 7 {
		t.Errorf("mask: got %d, want 7 (from inner)", got)
	}
	if inner.calls != 1 {
		t.Error("inner provider should be called on cache miss")
	}
	if stored := store.data["rbac:uid-1:orders"]; stored != 7 {
		t.Error("result should be stored in cache after miss")
	}
}
|
||
|
|
|
||
|
|
// TestCachedProviderCacheError checks that a cache read failure is not fatal:
// the wrapped provider's answer is returned instead.
func TestCachedProviderCacheError(t *testing.T) {
	inner := &mockProvider{mask: 5}
	unavailable := &mockCache{getErr: errors.New("valkey down")}
	provider := rbac.NewCachedPermissionProvider(inner, unavailable, time.Minute)

	got, err := provider.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 5 {
		t.Errorf("mask: got %d, want 5 (fallthrough on cache error)", got)
	}
}
|
||
|
|
|
||
|
|
// TestCachedProviderTenantKey checks that when the context Identity carries a
// tenant, the cache key includes it, and the tenant-less key is not written.
func TestCachedProviderTenantKey(t *testing.T) {
	const wantKey = "rbac:tenant-abc:uid-1:orders"

	inner := &mockProvider{mask: 3}
	store := &mockCache{data: map[string]int64{}}
	provider := rbac.NewCachedPermissionProvider(inner, store, time.Minute)

	tenantUser := security.NewIdentity("uid-1", "Alice", "alice@example.com").WithTenant("tenant-abc")
	ctx := security.SetInContext(context.Background(), tenantUser)

	if _, err := provider.ResolveMask(ctx, "uid-1", "orders"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if _, found := store.data[wantKey]; !found {
		t.Errorf("expected cache key %q; got keys: %v", wantKey, store.keys())
	}
	if _, found := store.data["rbac:uid-1:orders"]; found {
		t.Error("tenant-scoped key must not omit tenantID")
	}
}
|
||
|
|
|
||
|
|
// TestWithCacheKey checks that a key function supplied through WithCacheKey
// replaces the default key scheme. The function here folds a SecurityBag
// attribute (a hardware ID) into the key.
func TestWithCacheKey(t *testing.T) {
	const wantKey = "custom:uid-1:hw-xyz:orders"

	keyFn := func(bag security.SecurityBag, uid, resource string) string {
		hwID, _ := bag.Get("hardware_id") // absent attribute formats as <nil>
		return fmt.Sprintf("custom:%s:%v:%s", uid, hwID, resource)
	}

	inner := &mockProvider{mask: 7}
	store := &mockCache{data: map[string]int64{}}
	provider := rbac.NewCachedPermissionProvider(inner, store, time.Minute, rbac.WithCacheKey(keyFn))

	bag := security.NewSecurityBag(security.NewIdentity("uid-1", "Alice", "a@b.com")).With("hardware_id", "hw-xyz")
	ctx := security.SetBagInContext(context.Background(), bag)

	if _, err := provider.ResolveMask(ctx, "uid-1", "orders"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if _, found := store.data[wantKey]; !found {
		t.Errorf("expected custom cache key %q; got keys: %v", wantKey, store.keys())
	}
}
|
||
|
|
|
||
|
|
// ── rbac: ChainPermissionProvider ────────────────────────────────────────────
|
||
|
|
|
||
|
|
// TestChainProviderFirstNonZero checks that the chain stops at the first
// provider returning a non-zero mask.
func TestChainProviderFirstNonZero(t *testing.T) {
	primary, fallback := &mockProvider{mask: 7}, &mockProvider{mask: 99}
	chain := rbac.NewChainPermissionProvider(primary, fallback)

	got, err := chain.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 7 {
		t.Errorf("mask: got %d, want 7 (first non-zero)", got)
	}
	if fallback.calls != 0 {
		t.Error("second provider should not be called when first returns non-zero")
	}
}
|
||
|
|
|
||
|
|
// TestChainProviderFallthrough checks that a zero mask from one provider makes
// the chain move on to the next.
func TestChainProviderFallthrough(t *testing.T) {
	empty, fallback := &mockProvider{}, &mockProvider{mask: 5}
	chain := rbac.NewChainPermissionProvider(empty, fallback)

	got, err := chain.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 5 {
		t.Errorf("mask: got %d, want 5 (fallthrough to second)", got)
	}
}
|
||
|
|
|
||
|
|
// TestChainProviderError checks that an error from a provider aborts the chain
// and is returned without consulting later providers.
func TestChainProviderError(t *testing.T) {
	failing, fallback := &mockProvider{err: errors.New("provider down")}, &mockProvider{mask: 99}
	chain := rbac.NewChainPermissionProvider(failing, fallback)

	if _, err := chain.ResolveMask(context.Background(), "uid-1", "orders"); err == nil {
		t.Fatal("expected error, got nil")
	}
	if fallback.calls != 0 {
		t.Error("second provider should not be called after error in first")
	}
}
|
||
|
|
|
||
|
|
// ── helpers ───────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
func okHandler() http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
type mockEnricher struct {
|
||
|
|
identity security.Identity
|
||
|
|
err error
|
||
|
|
called bool
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockEnricher) Enrich(_ context.Context, _ string, _ map[string]any) (security.Identity, error) {
|
||
|
|
m.called = true
|
||
|
|
return m.identity, m.err
|
||
|
|
}
|
||
|
|
|
||
|
|
type mockProvider struct {
|
||
|
|
mask security.PermissionMask
|
||
|
|
err error
|
||
|
|
calls int
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockProvider) ResolveMask(_ context.Context, _, _ string) (security.PermissionMask, error) {
|
||
|
|
m.calls++
|
||
|
|
return m.mask, m.err
|
||
|
|
}
|
||
|
|
|
||
|
|
type mockCache struct {
|
||
|
|
data map[string]int64
|
||
|
|
getErr error
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockCache) Get(_ context.Context, key string) (int64, bool, error) {
|
||
|
|
if m.getErr != nil {
|
||
|
|
return 0, false, m.getErr
|
||
|
|
}
|
||
|
|
v, ok := m.data[key]
|
||
|
|
return v, ok, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockCache) Set(_ context.Context, key string, value int64, _ time.Duration) error {
|
||
|
|
if m.data == nil {
|
||
|
|
m.data = map[string]int64{}
|
||
|
|
}
|
||
|
|
m.data[key] = value
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockCache) keys() []string {
|
||
|
|
keys := make([]string, 0, len(m.data))
|
||
|
|
for k := range m.data {
|
||
|
|
keys = append(keys, k)
|
||
|
|
}
|
||
|
|
return keys
|
||
|
|
}
|