441 lines
13 KiB
Go
441 lines
13 KiB
Go
|
|
package web_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"errors"
|
||
|
|
"go/ast"
|
||
|
|
"go/parser"
|
||
|
|
"go/token"
|
||
|
|
"io"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/lifecycle"
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/observability"
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/security"
|
||
|
|
"code.nochebuena.dev/einherjar/core/logz"
|
||
|
|
"code.nochebuena.dev/einherjar/core/valid"
|
||
|
|
"code.nochebuena.dev/einherjar/core/xerrors"
|
||
|
|
web "code.nochebuena.dev/einherjar/web"
|
||
|
|
"code.nochebuena.dev/einherjar/web/health"
|
||
|
|
"code.nochebuena.dev/einherjar/web/httputil"
|
||
|
|
"code.nochebuena.dev/einherjar/web/mw"
|
||
|
|
"code.nochebuena.dev/einherjar/web/server"
|
||
|
|
)
|
||
|
|
|
||
|
|
// ── Compile-time interface satisfaction ──────────────────────────────────────
|
||
|
|
|
||
|
|
var _ mw.RateLimiterStore = (*mw.InMemoryRateLimiterStore)(nil)
|
||
|
|
|
||
|
|
// server.Server embeds lifecycle.Component — verified at compile time via assignment.
|
||
|
|
func init() {
|
||
|
|
var s server.Server = server.New(logz.New(logz.Config{}), server.Config{})
|
||
|
|
var _ lifecycle.Component = s
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── Structural: at most one exported TypeSpec per file ────────────────────────
|
||
|
|
|
||
|
|
func TestAtMostOneExportedTypePerFile(t *testing.T) {
|
||
|
|
fset := token.NewFileSet()
|
||
|
|
err := filepath.WalkDir(".", func(path string, d os.DirEntry, err error) error {
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if d.IsDir() && (d.Name() == ".git" || d.Name() == "vendor") {
|
||
|
|
return filepath.SkipDir
|
||
|
|
}
|
||
|
|
if !strings.HasSuffix(path, ".go") {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if strings.HasSuffix(path, "_test.go") {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if filepath.Base(path) == "doc.go" {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
f, parseErr := parser.ParseFile(fset, path, nil, 0)
|
||
|
|
if parseErr != nil {
|
||
|
|
t.Errorf("%s: parse error: %v", path, parseErr)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if count := countExportedTypes(f); count > 1 {
|
||
|
|
t.Errorf("%s: has %d exported type declarations; want at most 1", path, count)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("walk error: %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// countExportedTypes returns how many exported type names f declares at the
// top level. Only f.Decls is scanned, so types declared inside function
// bodies are never counted.
func countExportedTypes(f *ast.File) int {
	exported := 0
	for _, decl := range f.Decls {
		gen, isGen := decl.(*ast.GenDecl)
		if !isGen {
			continue
		}
		for _, spec := range gen.Specs {
			ts, isTypeSpec := spec.(*ast.TypeSpec)
			if !isTypeSpec || !ts.Name.IsExported() {
				continue
			}
			exported++
		}
	}
	return exported
}
|
||
|
|
|
||
|
|
// ── server ────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
func TestServerConfigDefaults(t *testing.T) {
|
||
|
|
cfg := server.Config{}
|
||
|
|
if cfg.Host != "" {
|
||
|
|
t.Errorf("Host zero value should be empty string (defaulted at runtime), got %q", cfg.Host)
|
||
|
|
}
|
||
|
|
if cfg.Port != 0 {
|
||
|
|
t.Errorf("Port zero value should be 0 (defaulted at runtime), got %d", cfg.Port)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServerNew(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
srv := server.New(logger, server.Config{})
|
||
|
|
if srv == nil {
|
||
|
|
t.Fatal("server.New returned nil")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── health ────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
type mockCheckable struct {
|
||
|
|
name string
|
||
|
|
priority observability.Level
|
||
|
|
err error
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockCheckable) HealthCheck(_ context.Context) error { return m.err }
|
||
|
|
func (m *mockCheckable) Name() string { return m.name }
|
||
|
|
func (m *mockCheckable) Priority() observability.Level { return m.priority }
|
||
|
|
|
||
|
|
var _ observability.Checkable = (*mockCheckable)(nil)
|
||
|
|
|
||
|
|
func TestHealthConfigDefaultTimeout(t *testing.T) {
|
||
|
|
// Zero Config should still work — defaultCheckTimeout applied inside handler.
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
h := health.NewHandlerWithConfig(logger, health.Config{})
|
||
|
|
if h == nil {
|
||
|
|
t.Fatal("NewHandlerWithConfig returned nil")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHealthHandlerAllUp(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
c := &mockCheckable{name: "db", priority: observability.LevelCritical}
|
||
|
|
h := health.NewHandler(logger, c)
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/health", nil))
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("status: got %d, want 200", w.Code)
|
||
|
|
}
|
||
|
|
var resp health.Response
|
||
|
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||
|
|
t.Fatalf("decode: %v", err)
|
||
|
|
}
|
||
|
|
if resp.Status != "UP" {
|
||
|
|
t.Errorf("overall status: got %q, want UP", resp.Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHealthHandlerCriticalDown(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
c := &mockCheckable{name: "db", priority: observability.LevelCritical, err: errors.New("connection refused")}
|
||
|
|
h := health.NewHandler(logger, c)
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/health", nil))
|
||
|
|
|
||
|
|
if w.Code != http.StatusServiceUnavailable {
|
||
|
|
t.Errorf("status: got %d, want 503", w.Code)
|
||
|
|
}
|
||
|
|
var resp health.Response
|
||
|
|
json.NewDecoder(w.Body).Decode(&resp)
|
||
|
|
if resp.Status != "DOWN" {
|
||
|
|
t.Errorf("overall status: got %q, want DOWN", resp.Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHealthHandlerDegradedDown(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
c := &mockCheckable{name: "cache", priority: observability.LevelDegraded, err: errors.New("timeout")}
|
||
|
|
h := health.NewHandler(logger, c)
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/health", nil))
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("status: got %d, want 200 (degraded is not 503)", w.Code)
|
||
|
|
}
|
||
|
|
var resp health.Response
|
||
|
|
json.NewDecoder(w.Body).Decode(&resp)
|
||
|
|
if resp.Status != "DEGRADED" {
|
||
|
|
t.Errorf("overall status: got %q, want DEGRADED", resp.Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── httputil ──────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
func TestHTTPUtilErrorMapping(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
cases := []struct {
|
||
|
|
code xerrors.Code
|
||
|
|
wantHTTP int
|
||
|
|
}{
|
||
|
|
{xerrors.ErrInvalidInput, 400},
|
||
|
|
{xerrors.ErrOutOfRange, 400},
|
||
|
|
{xerrors.ErrUnauthorized, 401},
|
||
|
|
{xerrors.ErrPermissionDenied, 403},
|
||
|
|
{xerrors.ErrNotFound, 404},
|
||
|
|
{xerrors.ErrAlreadyExists, 409},
|
||
|
|
{xerrors.ErrAborted, 409},
|
||
|
|
{xerrors.ErrGone, 410},
|
||
|
|
{xerrors.ErrPreconditionFailed, 412},
|
||
|
|
{xerrors.ErrRateLimited, 429},
|
||
|
|
{xerrors.ErrCancelled, 499},
|
||
|
|
{xerrors.ErrInternal, 500},
|
||
|
|
{xerrors.ErrDataLoss, 500},
|
||
|
|
{xerrors.ErrNotImplemented, 501},
|
||
|
|
{xerrors.ErrUnavailable, 503},
|
||
|
|
{xerrors.ErrDeadlineExceeded, 504},
|
||
|
|
}
|
||
|
|
for _, tc := range cases {
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
httputil.Error(logger, w, r, xerrors.New(tc.code, "test"))
|
||
|
|
if w.Code != tc.wantHTTP {
|
||
|
|
t.Errorf("code %s: HTTP status got %d, want %d", tc.code, w.Code, tc.wantHTTP)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPUtilHandle(t *testing.T) {
|
||
|
|
type req struct {
|
||
|
|
Name string `json:"name" validate:"required"`
|
||
|
|
}
|
||
|
|
type res struct {
|
||
|
|
Greeting string `json:"greeting"`
|
||
|
|
}
|
||
|
|
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
v := valid.New()
|
||
|
|
h := httputil.Handle(v, logger, func(_ context.Context, r req) (res, error) {
|
||
|
|
return res{Greeting: "hello " + r.Name}, nil
|
||
|
|
})
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"name":"world"}`)))
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("status: got %d, want 200", w.Code)
|
||
|
|
}
|
||
|
|
var out res
|
||
|
|
if err := json.NewDecoder(w.Body).Decode(&out); err != nil {
|
||
|
|
t.Fatalf("decode: %v", err)
|
||
|
|
}
|
||
|
|
if out.Greeting != "hello world" {
|
||
|
|
t.Errorf("greeting: got %q, want %q", out.Greeting, "hello world")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPUtilHandleValidationError(t *testing.T) {
|
||
|
|
type req struct {
|
||
|
|
Name string `json:"name" validate:"required"`
|
||
|
|
}
|
||
|
|
type res struct{}
|
||
|
|
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
v := valid.New()
|
||
|
|
h := httputil.Handle(v, logger, func(_ context.Context, r req) (res, error) {
|
||
|
|
return res{}, nil
|
||
|
|
})
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)))
|
||
|
|
|
||
|
|
if w.Code != http.StatusBadRequest {
|
||
|
|
t.Errorf("status: got %d, want 400", w.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPUtilHandleNoBody(t *testing.T) {
|
||
|
|
type res struct {
|
||
|
|
Value int `json:"value"`
|
||
|
|
}
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
h := httputil.HandleNoBody(logger, func(_ context.Context) (res, error) {
|
||
|
|
return res{Value: 42}, nil
|
||
|
|
})
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("status: got %d, want 200", w.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPUtilHandleEmpty(t *testing.T) {
|
||
|
|
type req struct {
|
||
|
|
ID string `json:"id" validate:"required"`
|
||
|
|
}
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
v := valid.New()
|
||
|
|
h := httputil.HandleEmpty(v, logger, func(_ context.Context, r req) error {
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(w, httptest.NewRequest(http.MethodDelete, "/", strings.NewReader(`{"id":"abc"}`)))
|
||
|
|
|
||
|
|
if w.Code != http.StatusNoContent {
|
||
|
|
t.Errorf("status: got %d, want 204", w.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── mw ────────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
func TestMWStatusRecorder(t *testing.T) {
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
rec := &mw.StatusRecorder{ResponseWriter: w, Status: http.StatusOK}
|
||
|
|
rec.WriteHeader(http.StatusCreated)
|
||
|
|
if rec.Status != http.StatusCreated {
|
||
|
|
t.Errorf("Status: got %d, want 201", rec.Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestInMemoryRateLimiterStore(t *testing.T) {
|
||
|
|
// burst=1, rps very low — first request passes, second is denied
|
||
|
|
store := mw.NewInMemoryRateLimiterStore(0.001, 1)
|
||
|
|
ctx := context.Background()
|
||
|
|
|
||
|
|
ok, err := store.Allow(ctx, "key1")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("unexpected error: %v", err)
|
||
|
|
}
|
||
|
|
if !ok {
|
||
|
|
t.Fatal("first request should be allowed")
|
||
|
|
}
|
||
|
|
|
||
|
|
ok2, err2 := store.Allow(ctx, "key1")
|
||
|
|
if err2 != nil {
|
||
|
|
t.Fatalf("unexpected error: %v", err2)
|
||
|
|
}
|
||
|
|
if ok2 {
|
||
|
|
t.Fatal("second request should be denied after burst exhausted")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestIPRateLimit(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
store := mw.NewInMemoryRateLimiterStore(0.001, 1)
|
||
|
|
handler := mw.IPRateLimit(store, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
make429 := func() int {
|
||
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
r.RemoteAddr = "10.0.0.1:1234"
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, r)
|
||
|
|
return w.Code
|
||
|
|
}
|
||
|
|
|
||
|
|
// First request passes (burst=1)
|
||
|
|
if code := make429(); code != http.StatusOK {
|
||
|
|
t.Errorf("first request: got %d, want 200", code)
|
||
|
|
}
|
||
|
|
// Second request is rate limited
|
||
|
|
if code := make429(); code != http.StatusTooManyRequests {
|
||
|
|
t.Errorf("second request: got %d, want 429", code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestIPRateLimitFailOpen(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
handler := mw.IPRateLimit(&errorStore{}, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
r.RemoteAddr = "10.0.0.1:1234"
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, r)
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("fail-open: got %d, want 200", w.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestUserRateLimit(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
store := mw.NewInMemoryRateLimiterStore(0.001, 1)
|
||
|
|
handler := mw.UserRateLimit(store, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
id := security.NewIdentity("user-abc", "Alice", "alice@example.com")
|
||
|
|
ctx := security.SetInContext(context.Background(), id)
|
||
|
|
|
||
|
|
makeReq := func() int {
|
||
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, r)
|
||
|
|
return w.Code
|
||
|
|
}
|
||
|
|
|
||
|
|
if code := makeReq(); code != http.StatusOK {
|
||
|
|
t.Errorf("first request: got %d, want 200", code)
|
||
|
|
}
|
||
|
|
if code := makeReq(); code != http.StatusTooManyRequests {
|
||
|
|
t.Errorf("second request (same user): got %d, want 429", code)
|
||
|
|
}
|
||
|
|
|
||
|
|
// A different IP (no identity) should have its own bucket
|
||
|
|
r2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
r2.RemoteAddr = "9.9.9.9:9999"
|
||
|
|
w2 := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w2, r2)
|
||
|
|
if w2.Code != http.StatusOK {
|
||
|
|
t.Errorf("different key: got %d, want 200", w2.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── web (root) ────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
func TestWebNew(t *testing.T) {
|
||
|
|
logger := logz.New(logz.Config{Writer: io.Discard})
|
||
|
|
srv := web.New(logger)
|
||
|
|
if srv == nil {
|
||
|
|
t.Fatal("web.New returned nil")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ── helpers ───────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
// errorStore is a RateLimiterStore that always returns an error (for fail-open tests).
|
||
|
|
type errorStore struct{}
|
||
|
|
|
||
|
|
func (e *errorStore) Allow(_ context.Context, _ string) (bool, error) {
|
||
|
|
return false, errors.New("store unavailable")
|
||
|
|
}
|
||
|
|
|
||
|
|
var _ mw.RateLimiterStore = (*errorStore)(nil)
|
||
|
|
|
||
|
|
// io.Discard is the sink used for the test loggers in this file. The standard
// library declares it as an io.Writer, so this assertion always holds.
var (
	_ io.Writer = io.Discard
)
|