269 lines
7.3 KiB
Go
269 lines
7.3 KiB
Go
|
|
package httputil
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"code.nochebuena.dev/go/xerrors"
|
||
|
|
)
|
||
|
|
|
||
|
|
// --- mock validator ---
|
||
|
|
|
||
|
|
// mockValidator is a test double for the validator argument of Handle and
// HandleEmpty. Struct ignores the value it is given and always returns the
// configured err, so a nil err means "every input is valid".
type mockValidator struct {
	err error // returned verbatim from every Struct call
}

// Struct reports the canned error regardless of its argument.
func (mv *mockValidator) Struct(_ any) error {
	return mv.err
}

// okValidator accepts every input: its err is nil.
var okValidator = &mockValidator{err: nil}
|
||
|
|
|
||
|
|
// --- helpers ---
|
||
|
|
|
||
|
|
// body wraps s in an io.Reader for use as an httptest request body.
func body(s string) io.Reader {
	r := strings.NewReader(s)
	return r
}
|
||
|
|
|
||
|
|
// decodeMap decodes the JSON recorded in rec.Body into a generic map and
// returns it. If the body cannot be decoded into a map it stops the test via
// t.Fatalf. The failure message includes the raw body so that an unexpected
// response (e.g. an error payload or plain text) is visible in the test output.
func decodeMap(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
	t.Helper()
	// Snapshot the body first: the decoder drains rec.Body, and the raw text
	// is only needed for the failure message.
	raw := rec.Body.String()
	var m map[string]any
	if err := json.NewDecoder(rec.Body).Decode(&m); err != nil {
		t.Fatalf("decode: %v (body: %q)", err, raw)
	}
	return m
}
|
||
|
|
|
||
|
|
// --- Handle ---
|
||
|
|
|
||
|
|
// req and res are the request and response payloads used by the Handle tests.
// Their fields carry no json tags, so they appear as "Value" and "Echo" in JSON.
type (
	req struct{ Value string }
	res struct{ Echo string }
)
|
||
|
|
|
||
|
|
func TestHandle_Success(t *testing.T) {
|
||
|
|
fn := func(ctx context.Context, r req) (res, error) { return res{Echo: r.Value}, nil }
|
||
|
|
h := Handle(okValidator, fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{"Value":"hello"}`)))
|
||
|
|
if rec.Code != http.StatusOK {
|
||
|
|
t.Errorf("want 200, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
m := decodeMap(t, rec)
|
||
|
|
if m["Echo"] != "hello" {
|
||
|
|
t.Errorf("want Echo=hello, got %v", m["Echo"])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandle_InvalidJSON(t *testing.T) {
|
||
|
|
fn := func(ctx context.Context, r req) (res, error) { return res{}, nil }
|
||
|
|
h := Handle(okValidator, fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`not json`)))
|
||
|
|
if rec.Code != http.StatusBadRequest {
|
||
|
|
t.Errorf("want 400, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandle_ValidationFails(t *testing.T) {
|
||
|
|
v := &mockValidator{err: xerrors.InvalidInput("field required")}
|
||
|
|
fn := func(ctx context.Context, r req) (res, error) { return res{}, nil }
|
||
|
|
h := Handle(v, fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{}`)))
|
||
|
|
if rec.Code != http.StatusBadRequest {
|
||
|
|
t.Errorf("want 400, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandle_FnError(t *testing.T) {
|
||
|
|
fn := func(ctx context.Context, r req) (res, error) {
|
||
|
|
return res{}, xerrors.New(xerrors.ErrNotFound, "not found")
|
||
|
|
}
|
||
|
|
h := Handle(okValidator, fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", body(`{}`)))
|
||
|
|
if rec.Code != http.StatusNotFound {
|
||
|
|
t.Errorf("want 404, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- HandleNoBody ---
|
||
|
|
|
||
|
|
func TestHandleNoBody_Success(t *testing.T) {
|
||
|
|
fn := func(ctx context.Context) (res, error) { return res{Echo: "ok"}, nil }
|
||
|
|
h := HandleNoBody(fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||
|
|
if rec.Code != http.StatusOK {
|
||
|
|
t.Errorf("want 200, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleNoBody_FnError(t *testing.T) {
|
||
|
|
fn := func(ctx context.Context) (res, error) {
|
||
|
|
return res{}, xerrors.New(xerrors.ErrUnavailable, "service down")
|
||
|
|
}
|
||
|
|
h := HandleNoBody(fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||
|
|
if rec.Code != http.StatusServiceUnavailable {
|
||
|
|
t.Errorf("want 503, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- HandleEmpty ---
|
||
|
|
|
||
|
|
func TestHandleEmpty_Success(t *testing.T) {
|
||
|
|
fn := func(ctx context.Context, r req) error { return nil }
|
||
|
|
h := HandleEmpty(okValidator, fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`)))
|
||
|
|
if rec.Code != http.StatusNoContent {
|
||
|
|
t.Errorf("want 204, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleEmpty_ValidationFails(t *testing.T) {
|
||
|
|
v := &mockValidator{err: xerrors.InvalidInput("required")}
|
||
|
|
fn := func(ctx context.Context, r req) error { return nil }
|
||
|
|
h := HandleEmpty(v, fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`)))
|
||
|
|
if rec.Code != http.StatusBadRequest {
|
||
|
|
t.Errorf("want 400, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleEmpty_FnError(t *testing.T) {
|
||
|
|
fn := func(ctx context.Context, r req) error {
|
||
|
|
return xerrors.New(xerrors.ErrPermissionDenied, "forbidden")
|
||
|
|
}
|
||
|
|
h := HandleEmpty(okValidator, fn)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/", body(`{}`)))
|
||
|
|
if rec.Code != http.StatusForbidden {
|
||
|
|
t.Errorf("want 403, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- HandlerFunc ---
|
||
|
|
|
||
|
|
func TestHandlerFunc_NoError(t *testing.T) {
|
||
|
|
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||
|
|
if rec.Code != http.StatusOK {
|
||
|
|
t.Errorf("want 200, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandlerFunc_WithError(t *testing.T) {
|
||
|
|
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||
|
|
return xerrors.New(xerrors.ErrUnauthorized, "unauthorized")
|
||
|
|
})
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
|
||
|
|
if rec.Code != http.StatusUnauthorized {
|
||
|
|
t.Errorf("want 401, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- JSON / Error / NoContent ---
|
||
|
|
|
||
|
|
func TestJSON_SetsContentType(t *testing.T) {
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
JSON(rec, http.StatusOK, map[string]string{"k": "v"})
|
||
|
|
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
|
||
|
|
t.Errorf("Content-Type: want application/json, got %s", ct)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestJSON_EncodesBody(t *testing.T) {
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
JSON(rec, http.StatusOK, map[string]string{"hello": "world"})
|
||
|
|
m := decodeMap(t, rec)
|
||
|
|
if m["hello"] != "world" {
|
||
|
|
t.Errorf("want hello=world, got %v", m)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestNoContent_Status(t *testing.T) {
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
NoContent(rec)
|
||
|
|
if rec.Code != http.StatusNoContent {
|
||
|
|
t.Errorf("want 204, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
if rec.Body.Len() != 0 {
|
||
|
|
t.Errorf("want empty body, got %q", rec.Body.String())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestError_XerrorsMapping(t *testing.T) {
|
||
|
|
cases := []struct {
|
||
|
|
code xerrors.Code
|
||
|
|
status int
|
||
|
|
}{
|
||
|
|
{xerrors.ErrInvalidInput, 400},
|
||
|
|
{xerrors.ErrUnauthorized, 401},
|
||
|
|
{xerrors.ErrPermissionDenied, 403},
|
||
|
|
{xerrors.ErrNotFound, 404},
|
||
|
|
{xerrors.ErrAlreadyExists, 409},
|
||
|
|
{xerrors.ErrInternal, 500},
|
||
|
|
{xerrors.ErrNotImplemented, 501},
|
||
|
|
{xerrors.ErrUnavailable, 503},
|
||
|
|
{xerrors.ErrDeadlineExceeded, 504},
|
||
|
|
}
|
||
|
|
for _, tc := range cases {
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
Error(rec, xerrors.New(tc.code, "msg"))
|
||
|
|
if rec.Code != tc.status {
|
||
|
|
t.Errorf("code %s: want %d, got %d", tc.code, tc.status, rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestError_UnknownError(t *testing.T) {
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
Error(rec, errors.New("oops"))
|
||
|
|
if rec.Code != http.StatusInternalServerError {
|
||
|
|
t.Errorf("want 500, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestError_NilError(t *testing.T) {
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
Error(rec, nil)
|
||
|
|
if rec.Code != http.StatusInternalServerError {
|
||
|
|
t.Errorf("want 500, got %d", rec.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestErrorCodeToStatus_AllCodes(t *testing.T) {
|
||
|
|
cases := []struct {
|
||
|
|
code xerrors.Code
|
||
|
|
status int
|
||
|
|
}{
|
||
|
|
{xerrors.ErrInvalidInput, 400},
|
||
|
|
{xerrors.ErrUnauthorized, 401},
|
||
|
|
{xerrors.ErrPermissionDenied, 403},
|
||
|
|
{xerrors.ErrNotFound, 404},
|
||
|
|
{xerrors.ErrAlreadyExists, 409},
|
||
|
|
{xerrors.ErrGone, 410},
|
||
|
|
{xerrors.ErrPreconditionFailed, 412},
|
||
|
|
{xerrors.ErrRateLimited, 429},
|
||
|
|
{xerrors.ErrInternal, 500},
|
||
|
|
{xerrors.ErrNotImplemented, 501},
|
||
|
|
{xerrors.ErrUnavailable, 503},
|
||
|
|
{xerrors.ErrDeadlineExceeded, 504},
|
||
|
|
{"UNKNOWN_CODE", 500},
|
||
|
|
}
|
||
|
|
for _, tc := range cases {
|
||
|
|
got := errorCodeToStatus(tc.code)
|
||
|
|
if got != tc.status {
|
||
|
|
t.Errorf("errorCodeToStatus(%s): want %d, got %d", tc.code, tc.status, got)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|