Files
auth/compliance_test.go

593 lines
18 KiB
Go
Raw Permalink Normal View History

feat(auth): initial implementation — authmw and rbac (v1.0.0) Introduces code.nochebuena.dev/einherjar/auth — the provider-agnostic HTTP authentication and authorization layer of the Einherjar framework. Absorbs two micro-lib packages (httpauth, rbac) as sub-packages, replacing the Identity-only context model with a SecurityBag-native design and adding a composable enrichment chain. authmw: - BagEnricher function type — enriches the request-scoped SecurityBag after the base Identity is built; registered via WithBagEnricher; multiple enrichers run in order, each receiving the bag from the previous - IdentityEnricher interface — application-layer contract for loading user data from uid+claims - EnrichmentMiddleware — builds SecurityBag from uid+claims, runs enricher chain, stores via security.SetBagInContext; 401 on missing uid, 500 on enricher error; routes all errors through httputil.Error - AuthzMiddleware — per-route permission gate; 401 on missing identity, 403 on provider error (fail-closed) or insufficient permissions - EnrichOpt type + WithTenantHeader (reads TenantID from header, implemented as a BagEnricher) + WithBagEnricher (registers custom enrichers for hardware IDs, grant codes, or any bag attribute) - SetTokenData / GetClaims — integration contract for auth-jwt / auth-firebase rbac: - NewClaimsPermissionProvider — reads flat JWT claim bitmasks from context; wildcard "*" fallback; handles int64/float64/json.Number; zero DB calls - NewCachedPermissionProvider — TTL cache wrapping any PermissionProvider; default key "rbac:{uid}:{resource}" or "rbac:{tenantID}:{uid}:{resource}"; TenantID sourced from SecurityBag automatically; accepts ...CachedOpt - CachedOpt type + WithCacheKey — overrides the key function for extra dimensions (hardware IDs, grant codes read from bag attributes) - NewChainPermissionProvider — tries providers in order; first non-zero wins; errors short-circuit; typical pattern: claims → cached DB fallback - Cache interface — pluggable 
backend satisfied by cache-valkey via duck typing Compliance test (package auth_test) enforces CT-6 (≤1 exported TypeSpec/file), compile-time interface satisfaction, and behavioural coverage across the full middleware and provider surface: enrichment success/failure, tenant header, custom BagEnricher, bag-in-context, authz allowed/denied/error, claims hit/wildcard/missing/float64, cached hit/miss/error/tenant-key/custom-key, chain first-non-zero/fallthrough/error. Depends on contracts v1.0.0, core v1.0.0, web v1.0.0. - identifiable.go: package-level Module variable (observability.Identifiable) for version identification — auth is middleware-only; not registered with the launcher
2026-05-29 16:11:21 +00:00
package auth_test
import (
	"context"
	"errors"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"testing"
	"time"

	"code.nochebuena.dev/einherjar/auth/authmw"
	"code.nochebuena.dev/einherjar/auth/rbac"
	"code.nochebuena.dev/einherjar/contracts/security"
	"code.nochebuena.dev/einherjar/core/logz"
)
// ── Compile-time interface satisfaction ──────────────────────────────────────

// Each blank-identifier assignment below stops compiling as soon as its
// right-hand side no longer satisfies the interface named on the left.
var (
	_ authmw.IdentityEnricher     = (*mockEnricher)(nil)
	_ security.PermissionProvider = rbac.NewClaimsPermissionProvider("x", authmw.GetClaims)
	_ security.PermissionProvider = rbac.NewCachedPermissionProvider(
		rbac.NewClaimsPermissionProvider("x", authmw.GetClaims),
		&mockCache{},
		time.Minute,
	)
	_ security.PermissionProvider = rbac.NewChainPermissionProvider()
	_ rbac.Cache                  = (*mockCache)(nil)
)
// ── Structural: at most one exported TypeSpec per file ────────────────────────

// TestAtMostOneExportedTypePerFile enforces CT-6: every non-test Go source file
// under the current directory declares at most one exported type. doc.go files
// are exempt.
//
// Directories are only ever traversed, never parsed, so a directory whose name
// happens to end in ".go" cannot produce a bogus parse error. Directories and
// files the go tool ignores (vendor, testdata, and names beginning with "." or
// "_") are skipped so fixtures and tooling artefacts cannot trigger failures.
func TestAtMostOneExportedTypePerFile(t *testing.T) {
	fset := token.NewFileSet()
	err := filepath.WalkDir(".", func(path string, d os.DirEntry, err error) error {
		if err != nil {
			return err
		}
		name := d.Name()
		if d.IsDir() {
			// The walk root is reported as "." and must not be treated as hidden.
			if path != "." && (name == "vendor" || name == "testdata" ||
				strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_")) {
				return filepath.SkipDir
			}
			return nil
		}
		if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") {
			return nil
		}
		if name == "doc.go" || strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
			return nil
		}
		f, parseErr := parser.ParseFile(fset, path, nil, 0)
		if parseErr != nil {
			t.Errorf("%s: parse error: %v", path, parseErr)
			return nil
		}
		if count := countExportedTypes(f); count > 1 {
			t.Errorf("%s: has %d exported type declarations; want at most 1", path, count)
		}
		return nil
	})
	if err != nil {
		t.Fatalf("walk error: %v", err)
	}
}
func countExportedTypes(f *ast.File) int {
count := 0
for _, decl := range f.Decls {
gd, ok := decl.(*ast.GenDecl)
if !ok {
continue
}
for _, spec := range gd.Specs {
ts, ok := spec.(*ast.TypeSpec)
if ok && ts.Name.IsExported() {
count++
}
}
}
return count
}
// ── authmw: SetTokenData ──────────────────────────────────────────────────────

// TestSetTokenData checks that the claims passed to SetTokenData come back from
// GetClaims on the derived context.
func TestSetTokenData(t *testing.T) {
	perms := map[string]any{"orders": int64(7)}
	ctx := authmw.SetTokenData(context.Background(), "uid-1", map[string]any{"perms": perms})

	claims := authmw.GetClaims(ctx)
	if claims == nil {
		t.Fatal("GetClaims returned nil after SetTokenData")
	}
	if _, found := claims["perms"]; !found {
		t.Error("claims key 'perms' missing")
	}
}
// ── authmw: EnrichmentMiddleware ──────────────────────────────────────────────

// TestEnrichmentMiddlewareSuccess covers the happy path. When the context
// carries a uid, the enricher runs and the wrapped handler is reached.
//
// httptest.NewRecorder starts with Code == 200. A 200 status alone therefore
// cannot prove the next handler ran, so the reached flag asserts it explicitly.
func TestEnrichmentMiddlewareSuccess(t *testing.T) {
	logger := logz.New(logz.Config{})
	enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}
	reached := false
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	h := authmw.EnrichmentMiddleware(logger, enricher)(next)
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", w.Code)
	}
	if !enricher.called {
		t.Error("enricher was not called")
	}
	if !reached {
		t.Error("next handler was not called")
	}
}
// TestEnrichmentMiddlewareMissingUID checks that a request without token data
// is rejected with 401 and that the enricher is never consulted.
func TestEnrichmentMiddlewareMissingUID(t *testing.T) {
	enricher := new(mockEnricher)
	mw := authmw.EnrichmentMiddleware(logz.New(logz.Config{}), enricher)

	rec := httptest.NewRecorder()
	mw(okHandler()).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("status: got %d, want 401", got)
	}
	if enricher.called {
		t.Error("enricher should not be called when uid is missing")
	}
}
// TestEnrichmentMiddlewareEnricherError checks that an error returned by the
// IdentityEnricher is surfaced as a 500 response.
func TestEnrichmentMiddlewareEnricherError(t *testing.T) {
	failing := &mockEnricher{err: errors.New("db down")}
	handler := authmw.EnrichmentMiddleware(logz.New(logz.Config{}), failing)(okHandler())

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(authmw.SetTokenData(req.Context(), "uid-1", nil))

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusInternalServerError {
		t.Errorf("status: got %d, want 500", got)
	}
}
// TestEnrichmentMiddlewareTenantHeader checks that WithTenantHeader copies the
// named request header into the Identity's TenantID.
//
// The handler sets reached because the recorder's status defaults to 200.
// Without the flag, a middleware that never calls next would let the in-handler
// assertions be skipped unnoticed. If the identity is missing, the handler
// returns early instead of inspecting a zero-value Identity.
func TestEnrichmentMiddlewareTenantHeader(t *testing.T) {
	logger := logz.New(logz.Config{})
	enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}
	reached := false
	h := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithTenantHeader("X-Tenant-ID"))(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			reached = true
			id, ok := security.FromContext(r.Context())
			if !ok {
				t.Error("identity not in context")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			if id.TenantID != "tenant-abc" {
				t.Errorf("TenantID: got %q, want %q", id.TenantID, "tenant-abc")
			}
			w.WriteHeader(http.StatusOK)
		}),
	)
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("X-Tenant-ID", "tenant-abc")
	r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", w.Code)
	}
	if !reached {
		t.Error("next handler was not called")
	}
}
// TestEnrichmentMiddlewareBagEnricher verifies that WithBagEnricher attaches a
// custom attribute to the SecurityBag stored in context.
//
// The handler sets reached because the recorder's status defaults to 200.
// Otherwise, a middleware that short-circuits without calling next would pass
// with none of the bag assertions executed.
func TestEnrichmentMiddlewareBagEnricher(t *testing.T) {
	logger := logz.New(logz.Config{})
	enricher := &mockEnricher{identity: security.NewIdentity("uid-1", "Alice", "alice@example.com")}
	hwEnricher := authmw.BagEnricher(func(bag security.SecurityBag, r *http.Request) security.SecurityBag {
		return bag.With("hardware_id", r.Header.Get("X-Hardware-ID"))
	})
	reached := false
	h := authmw.EnrichmentMiddleware(logger, enricher, authmw.WithBagEnricher(hwEnricher))(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			reached = true
			bag, ok := security.BagFromContext(r.Context())
			if !ok {
				t.Error("bag not in context")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			v, ok := bag.Get("hardware_id")
			if !ok {
				t.Error("hardware_id not in bag")
			} else if v != "hw-abc" {
				t.Errorf("hardware_id: got %v, want hw-abc", v)
			}
			w.WriteHeader(http.StatusOK)
		}),
	)
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("X-Hardware-ID", "hw-abc")
	r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", w.Code)
	}
	if !reached {
		t.Error("next handler was not called")
	}
}
// TestEnrichmentMiddlewareBagInContext verifies that security.BagFromContext works
// after enrichment, and that the bag carries the enriched Identity.
//
// The handler sets reached because the recorder's status defaults to 200.
// Without the flag, a middleware that never calls next would pass this test.
func TestEnrichmentMiddlewareBagInContext(t *testing.T) {
	logger := logz.New(logz.Config{})
	id := security.NewIdentity("uid-1", "Alice", "alice@example.com")
	enricher := &mockEnricher{identity: id}
	reached := false
	h := authmw.EnrichmentMiddleware(logger, enricher)(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			reached = true
			bag, ok := security.BagFromContext(r.Context())
			if !ok {
				t.Error("BagFromContext: want ok=true, got false")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			if bag.Identity() != id {
				t.Errorf("bag identity: got %+v, want %+v", bag.Identity(), id)
			}
			w.WriteHeader(http.StatusOK)
		}),
	)
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r = r.WithContext(authmw.SetTokenData(r.Context(), "uid-1", nil))
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", w.Code)
	}
	if !reached {
		t.Error("next handler was not called")
	}
}
// ── authmw: AuthzMiddleware ───────────────────────────────────────────────────

// TestAuthzMiddlewareAllowed checks that a request whose resolved mask grants
// the required permission is passed through to the next handler.
//
// The recorder's status defaults to 200, so the reached flag is what proves
// the request was actually forwarded rather than silently dropped.
func TestAuthzMiddlewareAllowed(t *testing.T) {
	logger := logz.New(logz.Config{})
	const readOrders = security.Permission(0)
	provider := &mockProvider{mask: security.PermissionMask(0).Grant(readOrders)}
	reached := false
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	h := authmw.AuthzMiddleware(logger, provider, "orders", readOrders)(next)
	id := security.NewIdentity("uid-1", "Alice", "alice@example.com")
	ctx := security.SetInContext(context.Background(), id)
	r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	if w.Code != http.StatusOK {
		t.Errorf("status: got %d, want 200", w.Code)
	}
	if !reached {
		t.Error("next handler was not called for an authorised request")
	}
}
// TestAuthzMiddlewareMissingIdentity checks that a request with no Identity in
// its context is rejected with 401.
func TestAuthzMiddlewareMissingIdentity(t *testing.T) {
	gate := authmw.AuthzMiddleware(logz.New(logz.Config{}), &mockProvider{}, "orders", 0)

	rec := httptest.NewRecorder()
	gate(okHandler()).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("status: got %d, want 401", got)
	}
}
// TestAuthzMiddlewarePermissionDenied checks that an authenticated caller whose
// resolved mask is empty receives 403.
func TestAuthzMiddlewarePermissionDenied(t *testing.T) {
	alice := security.NewIdentity("uid-1", "Alice", "alice@example.com")
	req := httptest.NewRequest(http.MethodGet, "/", nil).
		WithContext(security.SetInContext(context.Background(), alice))

	noPerms := &mockProvider{mask: 0}
	gate := authmw.AuthzMiddleware(logz.New(logz.Config{}), noPerms, "orders", security.Permission(0))

	rec := httptest.NewRecorder()
	gate(okHandler()).ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusForbidden {
		t.Errorf("status: got %d, want 403", got)
	}
}
// TestAuthzMiddlewareProviderError checks the fail-closed behaviour: when the
// permission provider errors, the request is refused with 403.
func TestAuthzMiddlewareProviderError(t *testing.T) {
	alice := security.NewIdentity("uid-1", "Alice", "alice@example.com")
	req := httptest.NewRequest(http.MethodGet, "/", nil).
		WithContext(security.SetInContext(context.Background(), alice))

	broken := &mockProvider{err: errors.New("db error")}
	gate := authmw.AuthzMiddleware(logz.New(logz.Config{}), broken, "orders", security.Permission(0))

	rec := httptest.NewRecorder()
	gate(okHandler()).ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusForbidden {
		t.Errorf("status: got %d, want 403 (fail-closed on provider error)", got)
	}
}
// ── rbac: ClaimsPermissionProvider ───────────────────────────────────────────

// TestClaimsProviderHit resolves an int64 mask stored under the exact resource
// name inside the "perms" claim.
func TestClaimsProviderHit(t *testing.T) {
	ctx := authmw.SetTokenData(context.Background(), "uid-1", map[string]any{
		"perms": map[string]any{"orders": int64(7)},
	})
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)

	got, err := provider.ResolveMask(ctx, "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 7 {
		t.Errorf("mask: got %d, want 7", got)
	}
}
// TestClaimsProviderWildcard checks that the "*" entry of the claim map applies
// to a resource that has no entry of its own.
func TestClaimsProviderWildcard(t *testing.T) {
	ctx := authmw.SetTokenData(context.Background(), "uid-1", map[string]any{
		"perms": map[string]any{"*": int64(15)},
	})
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)

	got, err := provider.ResolveMask(ctx, "uid-1", "products")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 15 {
		t.Errorf("mask: got %d, want 15 (wildcard)", got)
	}
}
// TestClaimsProviderMissing checks that a context with no claims resolves to an
// empty mask rather than an error.
func TestClaimsProviderMissing(t *testing.T) {
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)

	// No SetTokenData call: the context carries no claims at all.
	got, err := provider.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 0 {
		t.Errorf("mask: got %d, want 0", got)
	}
}
// TestClaimsProviderFloat64 checks that a float64 mask value in the claim map
// is converted to a PermissionMask.
func TestClaimsProviderFloat64(t *testing.T) {
	ctx := authmw.SetTokenData(context.Background(), "uid-1", map[string]any{
		"perms": map[string]any{"orders": float64(3)},
	})
	provider := rbac.NewClaimsPermissionProvider("perms", authmw.GetClaims)

	got, err := provider.ResolveMask(ctx, "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 3 {
		t.Errorf("mask: got %d, want 3 (float64 cast)", got)
	}
}
// ── rbac: CachedPermissionProvider ───────────────────────────────────────────

// TestCachedProviderHit checks that a value already in the cache under the
// default key is returned without consulting the wrapped provider.
func TestCachedProviderHit(t *testing.T) {
	backend := &mockProvider{mask: 7}
	store := &mockCache{data: map[string]int64{"rbac:uid-1:orders": 99}}
	cached := rbac.NewCachedPermissionProvider(backend, store, time.Minute)

	got, err := cached.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 99 {
		t.Errorf("mask: got %d, want 99 (from cache)", got)
	}
	if backend.calls != 0 {
		t.Error("inner provider should not be called on cache hit")
	}
}
// TestCachedProviderMiss checks that a cache miss falls through to the wrapped
// provider exactly once and stores its result under the default key.
func TestCachedProviderMiss(t *testing.T) {
	backend := &mockProvider{mask: 7}
	store := &mockCache{data: map[string]int64{}}

	got, err := rbac.NewCachedPermissionProvider(backend, store, time.Minute).
		ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 7 {
		t.Errorf("mask: got %d, want 7 (from inner)", got)
	}
	if backend.calls != 1 {
		t.Error("inner provider should be called on cache miss")
	}
	if store.data["rbac:uid-1:orders"] != 7 {
		t.Error("result should be stored in cache after miss")
	}
}
// TestCachedProviderCacheError checks that a failing cache read is not fatal:
// the wrapped provider's mask is returned instead.
func TestCachedProviderCacheError(t *testing.T) {
	backend := &mockProvider{mask: 5}
	unreachable := &mockCache{getErr: errors.New("valkey down")}

	got, err := rbac.NewCachedPermissionProvider(backend, unreachable, time.Minute).
		ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 5 {
		t.Errorf("mask: got %d, want 5 (fallthrough on cache error)", got)
	}
}
// TestCachedProviderTenantKey checks that when the context Identity carries a
// tenant, the default cache key includes the tenant ID. It also checks that the
// tenant-less key is not written.
func TestCachedProviderTenantKey(t *testing.T) {
	const wantKey = "rbac:tenant-abc:uid-1:orders"

	store := &mockCache{data: map[string]int64{}}
	cached := rbac.NewCachedPermissionProvider(&mockProvider{mask: 3}, store, time.Minute)
	tenantUser := security.NewIdentity("uid-1", "Alice", "alice@example.com").WithTenant("tenant-abc")
	ctx := security.SetInContext(context.Background(), tenantUser)

	if _, err := cached.ResolveMask(ctx, "uid-1", "orders"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if _, present := store.data[wantKey]; !present {
		t.Errorf("expected cache key %q; got keys: %v", wantKey, store.keys())
	}
	if _, present := store.data["rbac:uid-1:orders"]; present {
		t.Error("tenant-scoped key must not omit tenantID")
	}
}
// TestWithCacheKey checks that a key function supplied through WithCacheKey
// replaces the default one. Bag attributes such as a hardware ID can then
// become part of the cache key.
func TestWithCacheKey(t *testing.T) {
	const wantKey = "custom:uid-1:hw-xyz:orders"

	hardwareKey := func(sb security.SecurityBag, userID, res string) string {
		hw, _ := sb.Get("hardware_id")
		return fmt.Sprintf("custom:%s:%v:%s", userID, hw, res)
	}
	store := &mockCache{data: map[string]int64{}}
	cached := rbac.NewCachedPermissionProvider(&mockProvider{mask: 7}, store, time.Minute,
		rbac.WithCacheKey(hardwareKey),
	)

	identity := security.NewIdentity("uid-1", "Alice", "a@b.com")
	bag := security.NewSecurityBag(identity).With("hardware_id", "hw-xyz")
	if _, err := cached.ResolveMask(security.SetBagInContext(context.Background(), bag), "uid-1", "orders"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if _, present := store.data[wantKey]; !present {
		t.Errorf("expected custom cache key %q; got keys: %v", wantKey, store.keys())
	}
}
// ── rbac: ChainPermissionProvider ────────────────────────────────────────────

// TestChainProviderFirstNonZero checks that the chain stops at the first
// provider returning a non-zero mask and never consults the rest.
func TestChainProviderFirstNonZero(t *testing.T) {
	primary, secondary := &mockProvider{mask: 7}, &mockProvider{mask: 99}
	chain := rbac.NewChainPermissionProvider(primary, secondary)

	got, err := chain.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 7 {
		t.Errorf("mask: got %d, want 7 (first non-zero)", got)
	}
	if secondary.calls != 0 {
		t.Error("second provider should not be called when first returns non-zero")
	}
}
// TestChainProviderFallthrough checks that a zero mask from the first provider
// causes the chain to try the next one.
func TestChainProviderFallthrough(t *testing.T) {
	empty, granting := &mockProvider{mask: 0}, &mockProvider{mask: 5}
	chain := rbac.NewChainPermissionProvider(empty, granting)

	got, err := chain.ResolveMask(context.Background(), "uid-1", "orders")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != 5 {
		t.Errorf("mask: got %d, want 5 (fallthrough to second)", got)
	}
}
// TestChainProviderError checks that an error from a provider short-circuits
// the chain and is returned without trying later providers.
func TestChainProviderError(t *testing.T) {
	failing := &mockProvider{err: errors.New("provider down")}
	fallback := &mockProvider{mask: 99}
	chain := rbac.NewChainPermissionProvider(failing, fallback)

	if _, err := chain.ResolveMask(context.Background(), "uid-1", "orders"); err == nil {
		t.Fatal("expected error, got nil")
	}
	if fallback.calls != 0 {
		t.Error("second provider should not be called after error in first")
	}
}
// ── helpers ───────────────────────────────────────────────────────────────────
func okHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
}
// mockEnricher is a stub authmw.IdentityEnricher. Enrich records that it was
// invoked and returns the preset identity and error unchanged.
type mockEnricher struct {
	identity security.Identity // returned by every Enrich call
	err      error             // returned by every Enrich call; nil means success
	called   bool              // set once Enrich has been invoked
}

// Enrich ignores its arguments, marks the stub as called, and returns the
// configured identity and error.
func (e *mockEnricher) Enrich(context.Context, string, map[string]any) (security.Identity, error) {
	e.called = true
	return e.identity, e.err
}
// mockProvider is a stub security.PermissionProvider. It returns a fixed mask
// and error and counts how many times it was asked.
type mockProvider struct {
	mask  security.PermissionMask // returned by every ResolveMask call
	err   error                   // returned by every ResolveMask call
	calls int                     // number of ResolveMask invocations so far
}

// ResolveMask ignores the uid and resource, bumps the call counter, and
// returns the configured mask and error.
func (p *mockProvider) ResolveMask(context.Context, string, string) (security.PermissionMask, error) {
	p.calls++
	return p.mask, p.err
}
// mockCache is an in-memory rbac.Cache for tests. When getErr is set, Get fails
// with it, which simulates an unreachable backend. Set always succeeds and
// lazily allocates data, so a zero-value mockCache is usable.
type mockCache struct {
	data   map[string]int64
	getErr error
}

// Get returns the value stored under key and whether it was present. If getErr
// is set, it returns that error and ignores the stored data.
func (m *mockCache) Get(_ context.Context, key string) (int64, bool, error) {
	if m.getErr != nil {
		return 0, false, m.getErr
	}
	v, ok := m.data[key]
	return v, ok, nil
}

// Set stores value under key, allocating the map on first use. The TTL
// argument is ignored.
func (m *mockCache) Set(_ context.Context, key string, value int64, _ time.Duration) error {
	if m.data == nil {
		m.data = map[string]int64{}
	}
	m.data[key] = value
	return nil
}

// keys returns the stored keys in sorted order. Go map iteration order is
// randomised, so sorting keeps the failure messages that print this list
// stable from run to run.
func (m *mockCache) keys() []string {
	keys := make([]string, 0, len(m.data))
	for k := range m.data {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}