Files
httpmw/httpmw_test.go
Rene Nochebuena ad2a9e465e feat(httpmw): initial stable release v0.9.0
Standalone net/http middleware for panic recovery, CORS, request ID injection, and request logging.

What's included:
- Recover(): panic -> 500, captures debug.Stack, no logger required
- CORS(origins): OPTIONS 204 preflight, origin allowlist, package-wide method/header constants
- RequestID(generator): injects ID via logz.WithRequestID, sets X-Request-ID response header
- RequestLogger(logger): logs method/path/status/latency/request_id; Error for 5xx, Info otherwise
- Logger interface: Info, Error, With — duck-typed; satisfied by logz.Logger
- StatusRecorder: exported http.ResponseWriter wrapper that captures written status code
- Direct logz import for context helpers (documented exception to ADR-001)

Tested-via: todo-api POC integration
Reviewed-against: docs/adr/
2026-03-19 00:03:24 +00:00

184 lines
5.5 KiB
Go

package httpmw
import (
"net/http"
"net/http/httptest"
"testing"
"code.nochebuena.dev/go/logz"
)
// --- helpers ---
// newLogger builds a logz.Logger from a zero-value logz.Options.
func newLogger() logz.Logger {
	opts := logz.Options{}
	return logz.New(opts)
}
// testLogger is a minimal logger stub for tests. It remembers only the most
// recent call, encoded as "<level>:<message>", so a test can assert which
// level was used for a given message.
type testLogger struct {
	last string // "info:<msg>" or "error:<msg>" from the latest call; "" before any call
}

// record stores the latest log call as "<level>:<msg>".
func (t *testLogger) record(level, msg string) {
	t.last = level + ":" + msg
}

// Info records msg as an info-level entry. args are ignored.
func (t *testLogger) Info(msg string, args ...any) {
	t.record("info", msg)
}

// Error records msg as an error-level entry. err and args are ignored.
func (t *testLogger) Error(msg string, err error, args ...any) {
	t.record("error", msg)
}
// With ignores args and returns the receiver itself, so calls made through
// the returned Logger are recorded on the same testLogger.
func (t *testLogger) With(args ...any) Logger {
	return t
}
// chain applies the middleware mw to the handler function h and returns the
// resulting http.Handler.
func chain(mw func(http.Handler) http.Handler, h http.HandlerFunc) http.Handler {
	var inner http.Handler = h
	wrapped := mw(inner)
	return wrapped
}
// --- Recover ---
// TestRecover_NoPanic wraps a handler that writes 200 and does not panic
// with Recover, and asserts the recorded status is 200.
func TestRecover_NoPanic(t *testing.T) {
	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp := httptest.NewRecorder()

	chain(Recover(), ok).ServeHTTP(resp, req)

	if got := resp.Code; got != http.StatusOK {
		t.Errorf("want 200, got %d", got)
	}
}
// TestRecover_Panic wraps a handler that always panics with Recover and
// asserts the recorded status is 500.
func TestRecover_Panic(t *testing.T) {
	boom := func(w http.ResponseWriter, r *http.Request) {
		panic("test panic")
	}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp := httptest.NewRecorder()

	chain(Recover(), boom).ServeHTTP(resp, req)

	if got := resp.Code; got != http.StatusInternalServerError {
		t.Errorf("want 500, got %d", got)
	}
}
// --- CORS ---
// TestCORS_AllowedOrigin sends a GET whose Origin is the single entry in the
// allowlist and asserts that Access-Control-Allow-Origin equals that origin.
func TestCORS_AllowedOrigin(t *testing.T) {
	const origin = "https://example.com"

	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	h := chain(CORS([]string{origin}), ok)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", origin)
	resp := httptest.NewRecorder()
	h.ServeHTTP(resp, req)

	if got := resp.Header().Get("Access-Control-Allow-Origin"); got != origin {
		t.Errorf("expected CORS header, got %q", got)
	}
}
// TestCORS_AllowAll configures CORS with the wildcard "*" and sends a GET
// from an arbitrary origin. Only the presence of Access-Control-Allow-Origin
// is asserted, not its exact value.
func TestCORS_AllowAll(t *testing.T) {
	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	h := chain(CORS([]string{"*"}), ok)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", "https://any.com")
	resp := httptest.NewRecorder()
	h.ServeHTTP(resp, req)

	allowOrigin := resp.Header().Get("Access-Control-Allow-Origin")
	if allowOrigin == "" {
		t.Error("expected CORS header for wildcard origin")
	}
}
// TestCORS_OPTIONS_Preflight sends an OPTIONS request through CORS and
// asserts two things: the wrapped handler is never invoked, and the recorded
// status is 204.
func TestCORS_OPTIONS_Preflight(t *testing.T) {
	// Reaching this handler at all is a failure.
	mustNotRun := func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be called for OPTIONS")
	}

	preflight := httptest.NewRequest(http.MethodOptions, "/", nil)
	preflight.Header.Set("Origin", "https://any.com")
	resp := httptest.NewRecorder()

	chain(CORS([]string{"*"}), mustNotRun).ServeHTTP(resp, preflight)

	if got := resp.Code; got != http.StatusNoContent {
		t.Errorf("want 204, got %d", got)
	}
}
// --- RequestID ---
func TestRequestID_Generated(t *testing.T) {
var capturedID string
h := chain(RequestID(func() string { return "test-id-123" }),
func(w http.ResponseWriter, r *http.Request) {
capturedID = logz.GetRequestID(r.Context())
w.WriteHeader(http.StatusOK)
})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if capturedID != "test-id-123" {
t.Errorf("want test-id-123 in context, got %q", capturedID)
}
}
// TestRequestID_CustomGenerator asserts that the generator function handed
// to RequestID is invoked while serving a request.
func TestRequestID_CustomGenerator(t *testing.T) {
	invoked := false
	gen := func() string {
		invoked = true
		return "custom"
	}
	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	chain(RequestID(gen), ok).ServeHTTP(httptest.NewRecorder(), req)

	if !invoked {
		t.Error("custom generator not called")
	}
}
func TestRequestID_ContextReadable(t *testing.T) {
var id string
h := chain(RequestID(func() string { return "abc" }),
func(w http.ResponseWriter, r *http.Request) {
id = logz.GetRequestID(r.Context())
})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if id != "abc" {
t.Errorf("logz.GetRequestID: want abc, got %q", id)
}
}
// TestRequestID_HeaderSet asserts that the generated ID appears in the
// X-Request-ID response header.
func TestRequestID_HeaderSet(t *testing.T) {
	const wantID = "hdr-id"
	gen := func() string { return wantID }
	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}

	resp := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	chain(RequestID(gen), ok).ServeHTTP(resp, req)

	if got := resp.Header().Get("X-Request-ID"); got != wantID {
		t.Errorf("want X-Request-ID=hdr-id, got %q", got)
	}
}
// --- RequestLogger ---
// TestRequestLogger_Success serves a request whose handler writes 200 and
// asserts that the last call seen by the stub logger was Info with the
// message "http: request".
func TestRequestLogger_Success(t *testing.T) {
	const want = "info:http: request"

	stub := &testLogger{}
	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}

	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	chain(RequestLogger(stub), ok).ServeHTTP(httptest.NewRecorder(), req)

	if got := stub.last; got != want {
		t.Errorf("expected info log, got %q", got)
	}
}
// TestRequestLogger_Error serves a request whose handler writes 500 and
// asserts that the last call seen by the stub logger was Error with the
// message "http: request".
func TestRequestLogger_Error(t *testing.T) {
	const want = "error:http: request"

	stub := &testLogger{}
	fail := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}

	req := httptest.NewRequest(http.MethodGet, "/err", nil)
	chain(RequestLogger(stub), fail).ServeHTTP(httptest.NewRecorder(), req)

	if got := stub.last; got != want {
		t.Errorf("expected error log, got %q", got)
	}
}
// --- StatusRecorder ---
// TestStatusRecorder_Default builds a StatusRecorder with Status preset to
// 200 and asserts the field still reads 200 before anything is written.
//
// NOTE(review): the assertion only reads back the value supplied in the
// literal, so it cannot fail; consider asserting against whatever code
// constructs StatusRecorder in production — confirm where that happens.
func TestStatusRecorder_Default(t *testing.T) {
	sr := &StatusRecorder{
		ResponseWriter: httptest.NewRecorder(),
		Status:         http.StatusOK,
	}

	if got := sr.Status; got != http.StatusOK {
		t.Errorf("default status: want 200, got %d", got)
	}
}
// TestStatusRecorder_Capture calls WriteHeader(201) on a StatusRecorder that
// starts at 200 and asserts the Status field is updated to 201.
func TestStatusRecorder_Capture(t *testing.T) {
	sr := &StatusRecorder{
		ResponseWriter: httptest.NewRecorder(),
		Status:         http.StatusOK,
	}

	sr.WriteHeader(http.StatusCreated)

	if got := sr.Status; got != http.StatusCreated {
		t.Errorf("want 201, got %d", got)
	}
}