package contracts_test import ( "context" "go/ast" "go/parser" "go/token" "os" "path/filepath" "strings" "testing" "code.nochebuena.dev/einherjar/contracts/observability" "code.nochebuena.dev/einherjar/contracts/security" ) // TestOneExportPerFile asserts that every non-doc, non-test .go source file in the // contracts module exports exactly one top-level declaration. This mechanically // enforces CT-6: one file per interface or type. func TestOneExportPerFile(t *testing.T) { t.Parallel() root := "." fset := token.NewFileSet() err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { if err != nil { return err } if info.IsDir() && info.Name() == ".git" { return filepath.SkipDir } if !strings.HasSuffix(path, ".go") { return nil } if strings.HasSuffix(path, "_test.go") { return nil } if filepath.Base(path) == "doc.go" { return nil } f, parseErr := parser.ParseFile(fset, path, nil, 0) if parseErr != nil { t.Errorf("parse error in %s: %v", path, parseErr) return nil } count := countExportedTypes(f) if count != 1 { t.Errorf("%s: want exactly 1 exported type declaration, got %d", path, count) } return nil }) if err != nil { t.Fatalf("walk error: %v", err) } } // countExportedTypes counts top-level exported type declarations in a parsed file. // Constants, variables, and functions are not counted — they are part of a type's // API and may coexist with it in the same file. CT-6 requires one type per file, // not one symbol per file. func countExportedTypes(f *ast.File) int { count := 0 for _, decl := range f.Decls { if gd, ok := decl.(*ast.GenDecl); ok { for _, spec := range gd.Specs { if ts, ok := spec.(*ast.TypeSpec); ok && ast.IsExported(ts.Name.Name) { count++ } } } } return count } // TestIdentityIsValueType verifies that Identity.WithTenant returns a new value // without mutating the receiver. func TestIdentityIsValueType(t *testing.T) { t.Parallel() original := security.NewIdentity("uid-1", "Alice", "alice@example.com") enriched := original.WithTenant("tenant-abc") if original.TenantID != "" { t.Errorf("WithTenant mutated receiver: original.TenantID = %q, want empty", original.TenantID) } if enriched.TenantID != "tenant-abc" { t.Errorf("WithTenant: enriched.TenantID = %q, want %q", enriched.TenantID, "tenant-abc") } if enriched.UID != original.UID { t.Errorf("WithTenant: enriched.UID = %q, want %q", enriched.UID, original.UID) } } // TestIdentityContextRoundtrip verifies SetInContext and FromContext. func TestIdentityContextRoundtrip(t *testing.T) { t.Parallel() id := security.NewIdentity("uid-2", "Bob", "bob@example.com") ctx := security.SetInContext(context.Background(), id) got, ok := security.FromContext(ctx) if !ok { t.Fatal("FromContext: want ok=true, got false") } if got != id { t.Errorf("FromContext: got %+v, want %+v", got, id) } } // TestFromContextMissing verifies FromContext returns zero value and false on an // empty context. func TestFromContextMissing(t *testing.T) { t.Parallel() got, ok := security.FromContext(context.Background()) if ok { t.Error("FromContext on empty context: want ok=false, got true") } if got != (security.Identity{}) { t.Errorf("FromContext on empty context: got %+v, want zero value", got) } } // TestPermissionMaskRoundtrip verifies Grant and Has. func TestPermissionMaskRoundtrip(t *testing.T) { t.Parallel() const read security.Permission = 0 const write security.Permission = 1 const del security.Permission = 2 mask := security.PermissionMask(0).Grant(read).Grant(write) if !mask.Has(read) { t.Error("Has(read): want true, got false") } if !mask.Has(write) { t.Error("Has(write): want true, got false") } if mask.Has(del) { t.Error("Has(del): want false, got true") } } // TestPermissionMaskBoundaries verifies out-of-range positions are rejected. func TestPermissionMaskBoundaries(t *testing.T) { t.Parallel() full := security.PermissionMask(^int64(0)) if full.Has(-1) { t.Error("Has(-1): want false, got true") } if full.Has(63) { t.Error("Has(63): want false, got true") } if full.Has(security.MaxPermission + 1) { t.Errorf("Has(MaxPermission+1): want false, got true") } } // TestSecurityBagNewAndGet verifies construction and attribute access. func TestSecurityBagNewAndGet(t *testing.T) { t.Parallel() id := security.NewIdentity("uid-1", "Alice", "alice@example.com") bag := security.NewSecurityBag(id) if bag.Identity() != id { t.Errorf("Identity: got %+v, want %+v", bag.Identity(), id) } if _, ok := bag.Get("missing"); ok { t.Error("Get on empty bag: want ok=false, got true") } } // TestSecurityBagWith verifies that With stores a value and returns a new bag. func TestSecurityBagWith(t *testing.T) { t.Parallel() id := security.NewIdentity("uid-1", "Alice", "alice@example.com") bag := security.NewSecurityBag(id).With("hardware_id", "hw-abc") v, ok := bag.Get("hardware_id") if !ok { t.Fatal("Get: want ok=true, got false") } if v != "hw-abc" { t.Errorf("Get: got %v, want hw-abc", v) } } // TestSecurityBagImmutability verifies that With does not mutate the original bag. func TestSecurityBagImmutability(t *testing.T) { t.Parallel() id := security.NewIdentity("uid-1", "Alice", "alice@example.com") original := security.NewSecurityBag(id) _ = original.With("key", "value") if _, ok := original.Get("key"); ok { t.Error("With mutated the original bag: key should be absent in original") } } // TestSecurityBagWithIdentity verifies WithIdentity replaces the Identity. func TestSecurityBagWithIdentity(t *testing.T) { t.Parallel() id1 := security.NewIdentity("uid-1", "Alice", "alice@example.com") id2 := security.NewIdentity("uid-2", "Bob", "bob@example.com") bag := security.NewSecurityBag(id1).With("k", "v").WithIdentity(id2) if bag.Identity() != id2 { t.Errorf("WithIdentity: got %+v, want %+v", bag.Identity(), id2) } v, ok := bag.Get("k") if !ok || v != "v" { t.Error("WithIdentity must preserve existing attributes") } } // TestBagContextRoundtrip verifies SetBagInContext and BagFromContext. func TestBagContextRoundtrip(t *testing.T) { t.Parallel() id := security.NewIdentity("uid-1", "Alice", "alice@example.com") bag := security.NewSecurityBag(id).With("hardware_id", "hw-abc") ctx := security.SetBagInContext(context.Background(), bag) got, ok := security.BagFromContext(ctx) if !ok { t.Fatal("BagFromContext: want ok=true, got false") } if got.Identity() != id { t.Errorf("BagFromContext identity: got %+v, want %+v", got.Identity(), id) } v, ok := got.Get("hardware_id") if !ok || v != "hw-abc" { t.Errorf("BagFromContext attribute: got %v ok=%v, want hw-abc true", v, ok) } } // TestSetInContextReadableBag verifies that SetInContext stores a SecurityBag // readable by BagFromContext (backward-compatible storage upgrade). func TestSetInContextReadableBag(t *testing.T) { t.Parallel() id := security.NewIdentity("uid-1", "Alice", "alice@example.com") ctx := security.SetInContext(context.Background(), id) bag, ok := security.BagFromContext(ctx) if !ok { t.Fatal("BagFromContext after SetInContext: want ok=true, got false") } if bag.Identity() != id { t.Errorf("BagFromContext identity: got %+v, want %+v", bag.Identity(), id) } } // TestSetBagInContextReadableIdentity verifies that SetBagInContext stores a value // readable by FromContext (forward-compatible for existing callers). func TestSetBagInContextReadableIdentity(t *testing.T) { t.Parallel() id := security.NewIdentity("uid-1", "Alice", "alice@example.com") bag := security.NewSecurityBag(id).With("extra", "val") ctx := security.SetBagInContext(context.Background(), bag) got, ok := security.FromContext(ctx) if !ok { t.Fatal("FromContext after SetBagInContext: want ok=true, got false") } if got != id { t.Errorf("FromContext identity: got %+v, want %+v", got, id) } } // TestBagFromContextMissing verifies BagFromContext returns false on empty context. func TestBagFromContextMissing(t *testing.T) { t.Parallel() _, ok := security.BagFromContext(context.Background()) if ok { t.Error("BagFromContext on empty context: want ok=false, got true") } } // TestLevelCriticalIsZeroValue verifies that the zero value of Level is LevelCritical. func TestLevelCriticalIsZeroValue(t *testing.T) { t.Parallel() var l observability.Level if l != observability.LevelCritical { t.Errorf("zero Level = %d, want LevelCritical (%d)", l, observability.LevelCritical) } }