httputil depends on xerrors (Tier 0) and valid (Tier 1), placing it at Tier 2. No infrastructure or lifecycle dependencies exist in this module.
269 lines, 7.3 KiB, Go
package httputil
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"code.nochebuena.dev/go/xerrors"
|
|
)
|
|
|
|
// --- mock validator ---
|
|
|
|
// mockValidator is a test double for the validator dependency of Handle and
// HandleEmpty. Its Struct method ignores the value it is given and always
// reports the preconfigured err.
type mockValidator struct {
	err error // returned verbatim by Struct; nil means "valid"
}

// Struct returns the canned error without inspecting its argument.
func (mv *mockValidator) Struct(_ any) error {
	return mv.err
}

// okValidator accepts every value (its err is nil).
var okValidator = &mockValidator{}
|
|
|
|
// --- helpers ---
|
|
|
|
// body wraps s in an io.Reader for use as an httptest request body.
func body(s string) io.Reader {
	var r io.Reader = strings.NewReader(s)
	return r
}
|
|
|
|
// decodeMap decodes the JSON body captured by rec into a map[string]any.
// A body that cannot be decoded aborts the test via t.Fatalf.
func decodeMap(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
	t.Helper()

	var decoded map[string]any
	err := json.NewDecoder(rec.Body).Decode(&decoded)
	if err != nil {
		t.Fatalf("decode: %v", err)
	}
	return decoded
}
|
|
|
|
// --- Handle ---
|
|
|
|
// req and res are the request and response payloads exchanged with the
// generic handler adapters under test.
type (
	req struct{ Value string }
	res struct{ Echo string }
)
|
|
|
|
func TestHandle_Success(t *testing.T) {
|
|
fn := func(ctx context.Context, r req) (res, error) { return res{Echo: r.Value}, nil }
|
|
h := Handle(okValidator, fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{"Value":"hello"}`)))
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("want 200, got %d", rec.Code)
|
|
}
|
|
m := decodeMap(t, rec)
|
|
if m["Echo"] != "hello" {
|
|
t.Errorf("want Echo=hello, got %v", m["Echo"])
|
|
}
|
|
}
|
|
|
|
func TestHandle_InvalidJSON(t *testing.T) {
|
|
fn := func(ctx context.Context, r req) (res, error) { return res{}, nil }
|
|
h := Handle(okValidator, fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`not json`)))
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Errorf("want 400, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHandle_ValidationFails(t *testing.T) {
|
|
v := &mockValidator{err: xerrors.InvalidInput("field required")}
|
|
fn := func(ctx context.Context, r req) (res, error) { return res{}, nil }
|
|
h := Handle(v, fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{}`)))
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Errorf("want 400, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHandle_FnError(t *testing.T) {
|
|
fn := func(ctx context.Context, r req) (res, error) {
|
|
return res{}, xerrors.New(xerrors.ErrNotFound, "not found")
|
|
}
|
|
h := Handle(okValidator, fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{}`)))
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Errorf("want 404, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// --- HandleNoBody ---
|
|
|
|
func TestHandleNoBody_Success(t *testing.T) {
|
|
fn := func(ctx context.Context) (res, error) { return res{Echo: "ok"}, nil }
|
|
h := HandleNoBody(fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("want 200, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHandleNoBody_FnError(t *testing.T) {
|
|
fn := func(ctx context.Context) (res, error) {
|
|
return res{}, xerrors.New(xerrors.ErrUnavailable, "service down")
|
|
}
|
|
h := HandleNoBody(fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if rec.Code != http.StatusServiceUnavailable {
|
|
t.Errorf("want 503, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// --- HandleEmpty ---
|
|
|
|
func TestHandleEmpty_Success(t *testing.T) {
|
|
fn := func(ctx context.Context, r req) error { return nil }
|
|
h := HandleEmpty(okValidator, fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`)))
|
|
if rec.Code != http.StatusNoContent {
|
|
t.Errorf("want 204, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHandleEmpty_ValidationFails(t *testing.T) {
|
|
v := &mockValidator{err: xerrors.InvalidInput("required")}
|
|
fn := func(ctx context.Context, r req) error { return nil }
|
|
h := HandleEmpty(v, fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`)))
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Errorf("want 400, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHandleEmpty_FnError(t *testing.T) {
|
|
fn := func(ctx context.Context, r req) error {
|
|
return xerrors.New(xerrors.ErrPermissionDenied, "forbidden")
|
|
}
|
|
h := HandleEmpty(okValidator, fn)
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`)))
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("want 403, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// --- HandlerFunc ---
|
|
|
|
func TestHandlerFunc_NoError(t *testing.T) {
|
|
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
w.WriteHeader(http.StatusOK)
|
|
return nil
|
|
})
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("want 200, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestHandlerFunc_WithError(t *testing.T) {
|
|
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
return xerrors.New(xerrors.ErrUnauthorized, "unauthorized")
|
|
})
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("want 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// --- JSON / Error / NoContent ---
|
|
|
|
func TestJSON_SetsContentType(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
JSON(rec, http.StatusOK, map[string]string{"k": "v"})
|
|
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
|
|
t.Errorf("Content-Type: want application/json, got %s", ct)
|
|
}
|
|
}
|
|
|
|
func TestJSON_EncodesBody(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
JSON(rec, http.StatusOK, map[string]string{"hello": "world"})
|
|
m := decodeMap(t, rec)
|
|
if m["hello"] != "world" {
|
|
t.Errorf("want hello=world, got %v", m)
|
|
}
|
|
}
|
|
|
|
func TestNoContent_Status(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
NoContent(rec)
|
|
if rec.Code != http.StatusNoContent {
|
|
t.Errorf("want 204, got %d", rec.Code)
|
|
}
|
|
if rec.Body.Len() != 0 {
|
|
t.Errorf("want empty body, got %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestError_XerrorsMapping(t *testing.T) {
|
|
cases := []struct {
|
|
code xerrors.Code
|
|
status int
|
|
}{
|
|
{xerrors.ErrInvalidInput, 400},
|
|
{xerrors.ErrUnauthorized, 401},
|
|
{xerrors.ErrPermissionDenied, 403},
|
|
{xerrors.ErrNotFound, 404},
|
|
{xerrors.ErrAlreadyExists, 409},
|
|
{xerrors.ErrInternal, 500},
|
|
{xerrors.ErrNotImplemented, 501},
|
|
{xerrors.ErrUnavailable, 503},
|
|
{xerrors.ErrDeadlineExceeded, 504},
|
|
}
|
|
for _, tc := range cases {
|
|
rec := httptest.NewRecorder()
|
|
Error(rec, xerrors.New(tc.code, "msg"))
|
|
if rec.Code != tc.status {
|
|
t.Errorf("code %s: want %d, got %d", tc.code, tc.status, rec.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestError_UnknownError(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
Error(rec, errors.New("oops"))
|
|
if rec.Code != http.StatusInternalServerError {
|
|
t.Errorf("want 500, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestError_NilError(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
Error(rec, nil)
|
|
if rec.Code != http.StatusInternalServerError {
|
|
t.Errorf("want 500, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestErrorCodeToStatus_AllCodes(t *testing.T) {
|
|
cases := []struct {
|
|
code xerrors.Code
|
|
status int
|
|
}{
|
|
{xerrors.ErrInvalidInput, 400},
|
|
{xerrors.ErrUnauthorized, 401},
|
|
{xerrors.ErrPermissionDenied, 403},
|
|
{xerrors.ErrNotFound, 404},
|
|
{xerrors.ErrAlreadyExists, 409},
|
|
{xerrors.ErrGone, 410},
|
|
{xerrors.ErrPreconditionFailed, 412},
|
|
{xerrors.ErrRateLimited, 429},
|
|
{xerrors.ErrInternal, 500},
|
|
{xerrors.ErrNotImplemented, 501},
|
|
{xerrors.ErrUnavailable, 503},
|
|
{xerrors.ErrDeadlineExceeded, 504},
|
|
{"UNKNOWN_CODE", 500},
|
|
}
|
|
for _, tc := range cases {
|
|
got := errorCodeToStatus(tc.code)
|
|
if got != tc.status {
|
|
t.Errorf("errorCodeToStatus(%s): want %d, got %d", tc.code, tc.status, got)
|
|
}
|
|
}
|
|
}
|