package httpmw

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"code.nochebuena.dev/go/logz"
)

// --- helpers ---

// newLogger returns a logz.Logger built with default options.
//
// NOTE(review): not referenced by any test in this file; it may be used by
// sibling _test.go files in this package — remove it if it is not.
func newLogger() logz.Logger {
	return logz.New(logz.Options{})
}

// testLogger is a minimal Logger stub that records only the most recent call
// as "<level>:<msg>". Structured args and the error value are ignored; tests
// assert only on which level and message were used last.
type testLogger struct {
	last string
}

// Info records "info:" + msg.
func (l *testLogger) Info(msg string, args ...any) { l.last = "info:" + msg }

// Error records "error:" + msg; err and args are not captured.
func (l *testLogger) Error(msg string, err error, args ...any) { l.last = "error:" + msg }

// With returns the receiver itself, so calls on a derived logger are still
// recorded in last.
func (l *testLogger) With(args ...any) Logger { return l }

// chain applies a single middleware to h so tests can call ServeHTTP directly.
func chain(mw func(http.Handler) http.Handler, h http.HandlerFunc) http.Handler {
	return mw(h)
}

// --- Recover ---

func TestRecover_NoPanic(t *testing.T) {
	h := chain(Recover(), func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
}

// TestRecover_Panic checks that a panicking handler is turned into a 500
// response instead of escaping ServeHTTP.
func TestRecover_Panic(t *testing.T) {
	h := chain(Recover(), func(w http.ResponseWriter, r *http.Request) {
		panic("test panic")
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	if rec.Code != http.StatusInternalServerError {
		t.Errorf("want 500, got %d", rec.Code)
	}
}

// --- CORS ---

func TestCORS_AllowedOrigin(t *testing.T) {
	h := chain(CORS([]string{"https://example.com"}), func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", "https://example.com")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Header().Get("Access-Control-Allow-Origin") != "https://example.com" {
		t.Errorf("expected CORS header, got %q", rec.Header().Get("Access-Control-Allow-Origin"))
	}
}

func TestCORS_AllowAll(t *testing.T) {
	h := chain(CORS([]string{"*"}), func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", "https://any.com")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Header().Get("Access-Control-Allow-Origin") == "" {
		t.Error("expected CORS header for wildcard origin")
	}
}

// TestCORS_OPTIONS_Preflight checks that a preflight request is answered by
// the middleware with 204 and never reaches the wrapped handler.
func TestCORS_OPTIONS_Preflight(t *testing.T) {
	h := chain(CORS([]string{"*"}), func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be called for OPTIONS")
	})
	req := httptest.NewRequest(http.MethodOptions, "/", nil)
	req.Header.Set("Origin", "https://any.com")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusNoContent {
		t.Errorf("want 204, got %d", rec.Code)
	}
}

// --- RequestID ---

func TestRequestID_Generated(t *testing.T) {
	var capturedID string
	h := chain(RequestID(func() string { return "test-id-123" }), func(w http.ResponseWriter, r *http.Request) {
		capturedID = logz.GetRequestID(r.Context())
		w.WriteHeader(http.StatusOK)
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	if capturedID != "test-id-123" {
		t.Errorf("want test-id-123 in context, got %q", capturedID)
	}
}

func TestRequestID_CustomGenerator(t *testing.T) {
	called := false
	h := chain(RequestID(func() string { called = true; return "custom" }), func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	if !called {
		t.Error("custom generator not called")
	}
}

func TestRequestID_ContextReadable(t *testing.T) {
	var id string
	h := chain(RequestID(func() string { return "abc" }), func(w http.ResponseWriter, r *http.Request) {
		id = logz.GetRequestID(r.Context())
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	if id != "abc" {
		t.Errorf("logz.GetRequestID: want abc, got %q", id)
	}
}

func TestRequestID_HeaderSet(t *testing.T) {
	h := chain(RequestID(func() string { return "hdr-id" }), func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	if rec.Header().Get("X-Request-ID") != "hdr-id" {
		t.Errorf("want X-Request-ID=hdr-id, got %q", rec.Header().Get("X-Request-ID"))
	}
}

// --- RequestLogger ---

func TestRequestLogger_Success(t *testing.T) {
	lg := &testLogger{}
	h := chain(RequestLogger(lg), func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/foo", nil))
	if lg.last != "info:http: request" {
		t.Errorf("expected info log, got %q", lg.last)
	}
}

func TestRequestLogger_Error(t *testing.T) {
	lg := &testLogger{}
	h := chain(RequestLogger(lg), func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/err", nil))
	if lg.last != "error:http: request" {
		t.Errorf("expected error log, got %q", lg.last)
	}
}

// --- StatusRecorder ---

// TestStatusRecorder_Default checks that writing a body without an explicit
// WriteHeader leaves the seeded default status in place. (The previous
// version only re-read the literal it had just assigned, so it exercised
// nothing.)
func TestStatusRecorder_Default(t *testing.T) {
	rec := &StatusRecorder{ResponseWriter: httptest.NewRecorder(), Status: http.StatusOK}
	if _, err := rec.Write([]byte("body")); err != nil {
		t.Fatalf("write: %v", err)
	}
	if rec.Status != http.StatusOK {
		t.Errorf("default status: want 200, got %d", rec.Status)
	}
}

// TestStatusRecorder_Capture checks that WriteHeader both records the code and
// forwards it to the wrapped writer; without forwarding, clients would only
// ever see the implicit 200.
func TestStatusRecorder_Capture(t *testing.T) {
	inner := httptest.NewRecorder()
	rec := &StatusRecorder{ResponseWriter: inner, Status: http.StatusOK}
	rec.WriteHeader(http.StatusCreated)
	if rec.Status != http.StatusCreated {
		t.Errorf("want 201, got %d", rec.Status)
	}
	if inner.Code != http.StatusCreated {
		t.Errorf("underlying writer: want 201, got %d", inner.Code)
	}
}