package httputil import ( "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "strings" "testing" "code.nochebuena.dev/go/xerrors" ) // --- mock validator --- type mockValidator struct{ err error } func (m *mockValidator) Struct(v any) error { return m.err } var okValidator = &mockValidator{} // --- helpers --- func body(s string) io.Reader { return strings.NewReader(s) } func decodeMap(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { t.Helper() var m map[string]any if err := json.NewDecoder(rec.Body).Decode(&m); err != nil { t.Fatalf("decode: %v", err) } return m } // --- Handle --- type req struct{ Value string } type res struct{ Echo string } func TestHandle_Success(t *testing.T) { fn := func(ctx context.Context, r req) (res, error) { return res{Echo: r.Value}, nil } h := Handle(okValidator, fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{"Value":"hello"}`))) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } m := decodeMap(t, rec) if m["Echo"] != "hello" { t.Errorf("want Echo=hello, got %v", m["Echo"]) } } func TestHandle_InvalidJSON(t *testing.T) { fn := func(ctx context.Context, r req) (res, error) { return res{}, nil } h := Handle(okValidator, fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`not json`))) if rec.Code != http.StatusBadRequest { t.Errorf("want 400, got %d", rec.Code) } } func TestHandle_ValidationFails(t *testing.T) { v := &mockValidator{err: xerrors.InvalidInput("field required")} fn := func(ctx context.Context, r req) (res, error) { return res{}, nil } h := Handle(v, fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{}`))) if rec.Code != http.StatusBadRequest { t.Errorf("want 400, got %d", rec.Code) } } func TestHandle_FnError(t *testing.T) { fn := func(ctx context.Context, r req) (res, error) { return res{}, xerrors.New(xerrors.ErrNotFound, "not found") } h := Handle(okValidator, fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{}`))) if rec.Code != http.StatusNotFound { t.Errorf("want 404, got %d", rec.Code) } } // --- HandleNoBody --- func TestHandleNoBody_Success(t *testing.T) { fn := func(ctx context.Context) (res, error) { return res{Echo: "ok"}, nil } h := HandleNoBody(fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } } func TestHandleNoBody_FnError(t *testing.T) { fn := func(ctx context.Context) (res, error) { return res{}, xerrors.New(xerrors.ErrUnavailable, "service down") } h := HandleNoBody(fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusServiceUnavailable { t.Errorf("want 503, got %d", rec.Code) } } // --- HandleEmpty --- func TestHandleEmpty_Success(t *testing.T) { fn := func(ctx context.Context, r req) error { return nil } h := HandleEmpty(okValidator, fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`))) if rec.Code != http.StatusNoContent { t.Errorf("want 204, got %d", rec.Code) } } func TestHandleEmpty_ValidationFails(t *testing.T) { v := &mockValidator{err: xerrors.InvalidInput("required")} fn := func(ctx context.Context, r req) error { return nil } h := HandleEmpty(v, fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`))) if rec.Code != http.StatusBadRequest { t.Errorf("want 400, got %d", rec.Code) } } func TestHandleEmpty_FnError(t *testing.T) { fn := func(ctx context.Context, r req) error { return xerrors.New(xerrors.ErrPermissionDenied, "forbidden") } h := HandleEmpty(okValidator, fn) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`))) if rec.Code != http.StatusForbidden { t.Errorf("want 403, got %d", rec.Code) } } // --- HandlerFunc --- func TestHandlerFunc_NoError(t *testing.T) { h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { w.WriteHeader(http.StatusOK) return nil }) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } } func TestHandlerFunc_WithError(t *testing.T) { h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return xerrors.New(xerrors.ErrUnauthorized, "unauthorized") }) rec := httptest.NewRecorder() h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) if rec.Code != http.StatusUnauthorized { t.Errorf("want 401, got %d", rec.Code) } } // --- JSON / Error / NoContent --- func TestJSON_SetsContentType(t *testing.T) { rec := httptest.NewRecorder() JSON(rec, http.StatusOK, map[string]string{"k": "v"}) if ct := rec.Header().Get("Content-Type"); ct != "application/json" { t.Errorf("Content-Type: want application/json, got %s", ct) } } func TestJSON_EncodesBody(t *testing.T) { rec := httptest.NewRecorder() JSON(rec, http.StatusOK, map[string]string{"hello": "world"}) m := decodeMap(t, rec) if m["hello"] != "world" { t.Errorf("want hello=world, got %v", m) } } func TestNoContent_Status(t *testing.T) { rec := httptest.NewRecorder() NoContent(rec) if rec.Code != http.StatusNoContent { t.Errorf("want 204, got %d", rec.Code) } if rec.Body.Len() != 0 { t.Errorf("want empty body, got %q", rec.Body.String()) } } func TestError_XerrorsMapping(t *testing.T) { cases := []struct { code xerrors.Code status int }{ {xerrors.ErrInvalidInput, 400}, {xerrors.ErrUnauthorized, 401}, {xerrors.ErrPermissionDenied, 403}, {xerrors.ErrNotFound, 404}, {xerrors.ErrAlreadyExists, 409}, {xerrors.ErrInternal, 500}, {xerrors.ErrNotImplemented, 501}, {xerrors.ErrUnavailable, 503}, {xerrors.ErrDeadlineExceeded, 504}, } for _, tc := range cases { rec := httptest.NewRecorder() Error(rec, xerrors.New(tc.code, "msg")) if rec.Code != tc.status { t.Errorf("code %s: want %d, got %d", tc.code, tc.status, rec.Code) } } } func TestError_UnknownError(t *testing.T) { rec := httptest.NewRecorder() Error(rec, errors.New("oops")) if rec.Code != http.StatusInternalServerError { t.Errorf("want 500, got %d", rec.Code) } } func TestError_NilError(t *testing.T) { rec := httptest.NewRecorder() Error(rec, nil) if rec.Code != http.StatusInternalServerError { t.Errorf("want 500, got %d", rec.Code) } } func TestErrorCodeToStatus_AllCodes(t *testing.T) { cases := []struct { code xerrors.Code status int }{ {xerrors.ErrInvalidInput, 400}, {xerrors.ErrUnauthorized, 401}, {xerrors.ErrPermissionDenied, 403}, {xerrors.ErrNotFound, 404}, {xerrors.ErrAlreadyExists, 409}, {xerrors.ErrGone, 410}, {xerrors.ErrPreconditionFailed, 412}, {xerrors.ErrRateLimited, 429}, {xerrors.ErrInternal, 500}, {xerrors.ErrNotImplemented, 501}, {xerrors.ErrUnavailable, 503}, {xerrors.ErrDeadlineExceeded, 504}, {"UNKNOWN_CODE", 500}, } for _, tc := range cases { got := errorCodeToStatus(tc.code) if got != tc.status { t.Errorf("errorCodeToStatus(%s): want %d, got %d", tc.code, tc.status, got) } } }