Standalone net/http middleware for panic recovery, CORS, request ID injection, and request logging.

What's included:
- Recover(): converts a panic into a 500 response and captures debug.Stack; no logger required.
- CORS(origins): answers OPTIONS preflight requests with 204, enforces an origin allowlist, and uses package-wide method/header constants.
- RequestID(generator): injects the ID via logz.WithRequestID and sets the X-Request-ID response header.
- RequestLogger(logger): logs method, path, status, latency, and request_id; logs at Error for 5xx responses and at Info otherwise.
- Logger interface: Info, Error, With — duck-typed; satisfied by logz.Logger.
- StatusRecorder: exported http.ResponseWriter wrapper that captures the written status code.
- Direct logz import for context helpers (a documented exception to ADR-001).

Tested-via: todo-api POC integration
Reviewed-against: docs/adr/
184 lines · 5.5 KiB · Go
package httpmw
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"code.nochebuena.dev/go/logz"
|
|
)
|
|
|
|
// --- helpers ---
|
|
|
|
// newLogger returns a logz.Logger built from zero-value logz.Options.
// NOTE(review): no test in this file calls it; staticcheck may flag it as
// unused (U1000) — confirm whether it is still needed.
func newLogger() logz.Logger {
	opts := logz.Options{}
	return logz.New(opts)
}
|
|
|
|
type testLogger struct {
|
|
last string
|
|
}
|
|
|
|
func (t *testLogger) Info(msg string, args ...any) { t.last = "info:" + msg }
|
|
func (t *testLogger) Error(msg string, err error, args ...any) { t.last = "error:" + msg }
|
|
func (t *testLogger) With(args ...any) Logger { return t }
|
|
|
|
func chain(mw func(http.Handler) http.Handler, h http.HandlerFunc) http.Handler {
|
|
return mw(h)
|
|
}
|
|
|
|
// --- Recover ---
|
|
|
|
// TestRecover_NoPanic checks that Recover is transparent when the wrapped
// handler returns normally: the handler's own status reaches the client.
func TestRecover_NoPanic(t *testing.T) {
	ok := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	chain(Recover(), ok).ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("want 200, got %d", got)
	}
}
|
|
|
|
// TestRecover_Panic checks that a panic raised inside the handler is turned
// into a 500 response rather than escaping ServeHTTP.
func TestRecover_Panic(t *testing.T) {
	boom := func(http.ResponseWriter, *http.Request) {
		panic("test panic")
	}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	chain(Recover(), boom).ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusInternalServerError {
		t.Errorf("want 500, got %d", got)
	}
}
|
|
|
|
// --- CORS ---
|
|
|
|
func TestCORS_AllowedOrigin(t *testing.T) {
|
|
h := chain(CORS([]string{"https://example.com"}), func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "https://example.com")
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, req)
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "https://example.com" {
|
|
t.Errorf("expected CORS header, got %q", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
// TestCORS_AllowAll checks that a "*" allowlist emits some
// Access-Control-Allow-Origin value for an arbitrary Origin. The exact
// value ("*" vs. the echoed origin) is intentionally not pinned here.
func TestCORS_AllowAll(t *testing.T) {
	mw := CORS([]string{"*"})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", "https://any.com")
	rec := httptest.NewRecorder()

	chain(mw, func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}).ServeHTTP(rec, req)

	got := rec.Header().Get("Access-Control-Allow-Origin")
	if len(got) == 0 {
		t.Error("expected CORS header for wildcard origin")
	}
}
|
|
|
|
// TestCORS_OPTIONS_Preflight checks that an OPTIONS request is answered by the
// middleware itself with 204 and never reaches the wrapped handler.
func TestCORS_OPTIONS_Preflight(t *testing.T) {
	next := func(http.ResponseWriter, *http.Request) {
		t.Error("handler should not be called for OPTIONS")
	}

	req := httptest.NewRequest(http.MethodOptions, "/", nil)
	req.Header.Set("Origin", "https://any.com")
	rec := httptest.NewRecorder()

	chain(CORS([]string{"*"}), next).ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusNoContent {
		t.Errorf("want 204, got %d", got)
	}
}
|
|
|
|
// --- RequestID ---
|
|
|
|
// TestRequestID_Generated checks that RequestID calls the generator for each
// request rather than once up front, so consecutive requests carry distinct
// IDs in their context. (Reading the ID from context for a single request is
// covered by TestRequestID_ContextReadable.)
func TestRequestID_Generated(t *testing.T) {
	want := []string{"test-id-123", "test-id-456"}
	calls := 0
	gen := func() string {
		// Modulo guards against an index panic if the generator is invoked
		// more often than expected; the mismatch then surfaces below.
		id := want[calls%len(want)]
		calls++
		return id
	}

	var got []string
	h := chain(RequestID(gen), func(w http.ResponseWriter, r *http.Request) {
		got = append(got, logz.GetRequestID(r.Context()))
		w.WriteHeader(http.StatusOK)
	})

	for range want {
		h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	}

	if len(got) != len(want) {
		t.Fatalf("handler ran %d times, want %d", len(got), len(want))
	}
	for i := range want {
		if got[i] != want[i] {
			t.Errorf("request %d: want %q in context, got %q", i, want[i], got[i])
		}
	}
}
|
|
|
|
func TestRequestID_CustomGenerator(t *testing.T) {
|
|
called := false
|
|
h := chain(RequestID(func() string { called = true; return "custom" }),
|
|
func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if !called {
|
|
t.Error("custom generator not called")
|
|
}
|
|
}
|
|
|
|
func TestRequestID_ContextReadable(t *testing.T) {
|
|
var id string
|
|
h := chain(RequestID(func() string { return "abc" }),
|
|
func(w http.ResponseWriter, r *http.Request) {
|
|
id = logz.GetRequestID(r.Context())
|
|
})
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if id != "abc" {
|
|
t.Errorf("logz.GetRequestID: want abc, got %q", id)
|
|
}
|
|
}
|
|
|
|
func TestRequestID_HeaderSet(t *testing.T) {
|
|
h := chain(RequestID(func() string { return "hdr-id" }),
|
|
func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if rec.Header().Get("X-Request-ID") != "hdr-id" {
|
|
t.Errorf("want X-Request-ID=hdr-id, got %q", rec.Header().Get("X-Request-ID"))
|
|
}
|
|
}
|
|
|
|
// --- RequestLogger ---
|
|
|
|
func TestRequestLogger_Success(t *testing.T) {
|
|
lg := &testLogger{}
|
|
h := chain(RequestLogger(lg), func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/foo", nil))
|
|
if lg.last != "info:http: request" {
|
|
t.Errorf("expected info log, got %q", lg.last)
|
|
}
|
|
}
|
|
|
|
func TestRequestLogger_Error(t *testing.T) {
|
|
lg := &testLogger{}
|
|
h := chain(RequestLogger(lg), func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
})
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/err", nil))
|
|
if lg.last != "error:http: request" {
|
|
t.Errorf("expected error log, got %q", lg.last)
|
|
}
|
|
}
|
|
|
|
// --- StatusRecorder ---
|
|
|
|
// TestStatusRecorder_Default checks the implicit-200 path: when a handler
// writes a body without calling WriteHeader, the recorder keeps its initial
// 200 and the body still reaches the underlying writer. (The previous version
// only re-read the field it had just assigned, so it exercised no code.)
func TestStatusRecorder_Default(t *testing.T) {
	inner := httptest.NewRecorder()
	rec := &StatusRecorder{ResponseWriter: inner, Status: http.StatusOK}

	if _, err := rec.Write([]byte("ok")); err != nil {
		t.Fatalf("write: %v", err)
	}

	if rec.Status != http.StatusOK {
		t.Errorf("default status: want 200, got %d", rec.Status)
	}
	if inner.Code != http.StatusOK {
		t.Errorf("underlying writer: want 200, got %d", inner.Code)
	}
	if body := inner.Body.String(); body != "ok" {
		t.Errorf("underlying body: want %q, got %q", "ok", body)
	}
}
|
|
|
|
// TestStatusRecorder_Capture checks that WriteHeader both records the status
// on the recorder and forwards it to the wrapped ResponseWriter; a wrapper
// that captured without forwarding would silently change client responses.
func TestStatusRecorder_Capture(t *testing.T) {
	inner := httptest.NewRecorder()
	rec := &StatusRecorder{ResponseWriter: inner, Status: http.StatusOK}

	rec.WriteHeader(http.StatusCreated)

	if rec.Status != http.StatusCreated {
		t.Errorf("want 201, got %d", rec.Status)
	}
	if inner.Code != http.StatusCreated {
		t.Errorf("underlying writer: want 201, got %d", inner.Code)
	}
}
|