// Extracted from httpauth/httpauth_test.go (Go; listed as 302 lines, 9.1 KiB).

package httpauth
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"code.nochebuena.dev/go/rbac"
)
// --- mocks ---
// mockEnricher is a test double whose Enrich method hands back a canned
// identity and error, whatever arguments it receives.
type mockEnricher struct {
	identity rbac.Identity // value returned by Enrich
	err      error         // error returned by Enrich
}

// Enrich ignores all of its arguments and returns the preconfigured
// identity/error pair.
func (e *mockEnricher) Enrich(context.Context, string, map[string]any) (rbac.Identity, error) {
	return e.identity, e.err
}
// mockProvider is a test double whose ResolveMask method hands back a canned
// permission mask and error, whatever arguments it receives.
type mockProvider struct {
	mask rbac.PermissionMask // value returned by ResolveMask
	err  error               // error returned by ResolveMask
}

// ResolveMask ignores all of its arguments and returns the preconfigured
// mask/error pair.
func (p *mockProvider) ResolveMask(context.Context, string, string) (rbac.PermissionMask, error) {
	return p.mask, p.err
}
// mockCache is a single-slot test double for a mask cache. Get reports the
// stored fields verbatim; Set records the written value.
type mockCache struct {
	val    int64 // value reported by Get and overwritten by Set
	exists bool  // hit/miss flag reported by Get
	getErr error // error reported by Get
	setErr error // error reported by Set
}

// Get ignores the context and key and returns the configured value, hit flag
// and error.
func (c *mockCache) Get(context.Context, string) (int64, bool, error) {
	return c.val, c.exists, c.getErr
}

// Set stores val (even when setErr is non-nil) and returns setErr. The key,
// the TTL and the exists flag are left untouched.
func (c *mockCache) Set(_ context.Context, _ string, val int64, _ time.Duration) error {
	c.val = val
	return c.setErr
}
// testPerm is the permission (numeric value 1) used throughout the authz
// tests: it is the permission passed to AuthzMiddleware and the one granted
// on the mock provider's mask in the "allowed" case.
const testPerm rbac.Permission = 1
// injectTokenData returns a handler that stores uid and claims in the request
// context via SetTokenData and then delegates to next. Tests use it in place
// of an upstream AuthMiddleware when exercising downstream middleware.
func injectTokenData(uid string, claims map[string]any, next http.Handler) http.Handler {
	withToken := func(w http.ResponseWriter, r *http.Request) {
		r = r.WithContext(SetTokenData(r.Context(), uid, claims))
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(withToken)
}
// --- EnrichmentMiddleware ---
// TestEnrichmentMiddleware_Success checks that, with token data present and a
// succeeding enricher, the request reaches the next handler (200) and the
// identity found in its context carries the expected UID.
func TestEnrichmentMiddleware_Success(t *testing.T) {
	enricher := &mockEnricher{identity: rbac.NewIdentity("uid1", "Alice", "alice@example.com")}

	var seen rbac.Identity
	capture := func(w http.ResponseWriter, r *http.Request) {
		// The "ok" result is deliberately dropped; the UID assertion below
		// catches a missing identity.
		seen, _ = rbac.FromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}
	chain := injectTokenData("uid1", nil, EnrichmentMiddleware(enricher)(http.HandlerFunc(capture)))

	resp := httptest.NewRecorder()
	chain.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/", nil))

	if status := resp.Code; status != http.StatusOK {
		t.Errorf("want 200, got %d", status)
	}
	if uid := seen.UID; uid != "uid1" {
		t.Errorf("want uid1, got %q", uid)
	}
}
func TestEnrichmentMiddleware_NoUID(t *testing.T) {
me := &mockEnricher{}
h := EnrichmentMiddleware(me)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusUnauthorized {
t.Errorf("want 401, got %d", rec.Code)
}
}
// TestEnrichmentMiddleware_EnricherError checks that an error returned by the
// enricher results in a 500 response.
func TestEnrichmentMiddleware_EnricherError(t *testing.T) {
	failing := &mockEnricher{err: errors.New("db error")}
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	chain := injectTokenData("uid1", nil, EnrichmentMiddleware(failing)(next))

	resp := httptest.NewRecorder()
	chain.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/", nil))

	if status := resp.Code; status != http.StatusInternalServerError {
		t.Errorf("want 500, got %d", status)
	}
}
// TestEnrichmentMiddleware_WithTenantHeader checks that, when the middleware
// is configured with WithTenantHeader("X-Tenant-ID") and the request carries
// that header, the identity seen by the next handler has the header value as
// its TenantID.
func TestEnrichmentMiddleware_WithTenantHeader(t *testing.T) {
	enricher := &mockEnricher{identity: rbac.NewIdentity("uid1", "", "")}

	var seen rbac.Identity
	capture := func(w http.ResponseWriter, r *http.Request) {
		seen, _ = rbac.FromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}
	mw := EnrichmentMiddleware(enricher, WithTenantHeader("X-Tenant-ID"))
	chain := injectTokenData("uid1", nil, mw(http.HandlerFunc(capture)))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("X-Tenant-ID", "tenant-abc")
	chain.ServeHTTP(httptest.NewRecorder(), req)

	if tenant := seen.TenantID; tenant != "tenant-abc" {
		t.Errorf("want tenant-abc, got %q", tenant)
	}
}
// TestEnrichmentMiddleware_TenantHeaderAbsent checks that, when the middleware
// is configured with WithTenantHeader("X-Tenant-ID") but the request does not
// carry that header, the next handler still runs and the identity it sees has
// an empty TenantID.
func TestEnrichmentMiddleware_TenantHeaderAbsent(t *testing.T) {
	me := &mockEnricher{identity: rbac.NewIdentity("uid1", "", "")}

	var (
		got    rbac.Identity
		called bool
	)
	inner := EnrichmentMiddleware(me, WithTenantHeader("X-Tenant-ID"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		got, _ = rbac.FromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))
	h := injectTokenData("uid1", nil, inner)

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	// Without this guard the TenantID assertion passes vacuously whenever the
	// middleware short-circuits, because got would keep its zero value.
	if !called {
		t.Fatalf("next handler was not called (status %d)", rec.Code)
	}
	if got.TenantID != "" {
		t.Errorf("want empty TenantID, got %q", got.TenantID)
	}
}
// --- AuthzMiddleware ---
// TestAuthzMiddleware_Allowed checks that a request with an identity in its
// context passes through (200) when the provider's mask grants testPerm.
func TestAuthzMiddleware_Allowed(t *testing.T) {
	provider := &mockProvider{mask: rbac.PermissionMask(0).Grant(testPerm)}
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := AuthzMiddleware(provider, "resource", testPerm)(next)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(rbac.SetInContext(req.Context(), rbac.NewIdentity("uid1", "", "")))

	resp := httptest.NewRecorder()
	guarded.ServeHTTP(resp, req)

	if status := resp.Code; status != http.StatusOK {
		t.Errorf("want 200, got %d", status)
	}
}
// TestAuthzMiddleware_Denied checks that a request with an identity in its
// context is answered with 403 when the provider returns an empty mask.
func TestAuthzMiddleware_Denied(t *testing.T) {
	provider := &mockProvider{mask: rbac.PermissionMask(0)}
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := AuthzMiddleware(provider, "resource", testPerm)(next)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(rbac.SetInContext(req.Context(), rbac.NewIdentity("uid1", "", "")))

	resp := httptest.NewRecorder()
	guarded.ServeHTTP(resp, req)

	if status := resp.Code; status != http.StatusForbidden {
		t.Errorf("want 403, got %d", status)
	}
}
func TestAuthzMiddleware_NoIdentity(t *testing.T) {
mp := &mockProvider{}
h := AuthzMiddleware(mp, "resource", testPerm)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusUnauthorized {
t.Errorf("want 401, got %d", rec.Code)
}
}
// TestAuthzMiddleware_ProviderError checks that an error from the permission
// provider is answered with 403 rather than letting the request through.
func TestAuthzMiddleware_ProviderError(t *testing.T) {
	failing := &mockProvider{err: errors.New("db error")}
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := AuthzMiddleware(failing, "resource", testPerm)(next)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(rbac.SetInContext(req.Context(), rbac.NewIdentity("uid1", "", "")))

	resp := httptest.NewRecorder()
	guarded.ServeHTTP(resp, req)

	if status := resp.Code; status != http.StatusForbidden {
		t.Errorf("want 403, got %d", status)
	}
}
// --- ClaimsPermissionProvider ---
func TestClaimsPermissionProvider_Float64(t *testing.T) {
p := NewClaimsPermissionProvider("permisos")
ctx := SetTokenData(context.Background(), "uid1", map[string]any{
"permisos": map[string]any{"usuarios": float64(515)},
})
mask, err := p.ResolveMask(ctx, "uid1", "usuarios")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if mask != 515 {
t.Errorf("want 515, got %d", mask)
}
}
func TestClaimsPermissionProvider_Int64(t *testing.T) {
p := NewClaimsPermissionProvider("permisos")
ctx := SetTokenData(context.Background(), "uid1", map[string]any{
"permisos": map[string]any{"roles": int64(6)},
})
mask, err := p.ResolveMask(ctx, "uid1", "roles")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if mask != 6 {
t.Errorf("want 6, got %d", mask)
}
}
// TestClaimsPermissionProvider_NoClaims checks that resolving against a
// context with no token data yields a zero mask and no error.
func TestClaimsPermissionProvider_NoClaims(t *testing.T) {
	provider := NewClaimsPermissionProvider("permisos")

	mask, err := provider.ResolveMask(context.Background(), "uid1", "usuarios")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case mask != 0:
		t.Errorf("want 0, got %d", mask)
	}
}
func TestClaimsPermissionProvider_ResourceAbsent(t *testing.T) {
p := NewClaimsPermissionProvider("permisos")
ctx := SetTokenData(context.Background(), "uid1", map[string]any{
"permisos": map[string]any{},
})
mask, err := p.ResolveMask(ctx, "uid1", "usuarios")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if mask != 0 {
t.Errorf("want 0, got %d", mask)
}
}
// --- CachedPermissionProvider ---
// TestCachedPermissionProvider_Hit checks that a cache hit is returned as-is:
// the cached 515 wins over the 999 the wrapped provider would return.
func TestCachedPermissionProvider_Hit(t *testing.T) {
	backend := &mockProvider{mask: 999}
	store := &mockCache{val: 515, exists: true}
	provider := NewCachedPermissionProvider(backend, store, time.Minute)

	mask, err := provider.ResolveMask(context.Background(), "uid1", "usuarios")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case mask != 515:
		t.Errorf("want 515 from cache, got %d", mask)
	}
}
// TestCachedPermissionProvider_Miss checks that on a cache miss the mask comes
// from the wrapped provider and is written back to the cache.
func TestCachedPermissionProvider_Miss(t *testing.T) {
	backend := &mockProvider{mask: 515}
	store := &mockCache{exists: false}
	provider := NewCachedPermissionProvider(backend, store, time.Minute)

	mask, err := provider.ResolveMask(context.Background(), "uid1", "usuarios")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if mask != 515 {
		t.Errorf("want 515 from inner, got %d", mask)
	}
	// mockCache.Set records the written value in val.
	if stored := store.val; stored != 515 {
		t.Errorf("expected cache populated with 515, got %d", stored)
	}
}
// TestCachedPermissionProvider_CacheErrorFallsThrough checks that an error
// from the cache's Get does not surface to the caller: the mask is taken from
// the wrapped provider instead.
func TestCachedPermissionProvider_CacheErrorFallsThrough(t *testing.T) {
	backend := &mockProvider{mask: 515}
	broken := &mockCache{getErr: errors.New("valkey unavailable")}
	provider := NewCachedPermissionProvider(backend, broken, time.Minute)

	mask, err := provider.ResolveMask(context.Background(), "uid1", "usuarios")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case mask != 515:
		t.Errorf("want 515 from inner on cache error, got %d", mask)
	}
}
// TestCachedPermissionProvider_InnerError checks that, on a cache miss, an
// error from the wrapped provider is returned to the caller.
func TestCachedPermissionProvider_InnerError(t *testing.T) {
	failing := &mockProvider{err: errors.New("db error")}
	store := &mockCache{exists: false}
	provider := NewCachedPermissionProvider(failing, store, time.Minute)

	if _, err := provider.ResolveMask(context.Background(), "uid1", "usuarios"); err == nil {
		t.Error("expected error from inner, got nil")
	}
}