189 lines
5.1 KiB
Go
189 lines
5.1 KiB
Go
|
|
package health
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"errors"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// --- mock helpers ---
|
||
|
|
|
||
|
|
type mockCheck struct {
|
||
|
|
name string
|
||
|
|
priority Level
|
||
|
|
err error
|
||
|
|
delay time.Duration
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockCheck) HealthCheck(ctx context.Context) error {
|
||
|
|
if m.delay > 0 {
|
||
|
|
select {
|
||
|
|
case <-time.After(m.delay):
|
||
|
|
case <-ctx.Done():
|
||
|
|
return ctx.Err()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return m.err
|
||
|
|
}
|
||
|
|
func (m *mockCheck) Name() string { return m.name }
|
||
|
|
func (m *mockCheck) Priority() Level { return m.priority }
|
||
|
|
|
||
|
|
// noopLogger is a logger that discards every message, keeping test output
// quiet.
type noopLogger struct{}
|
||
|
|
|
||
|
|
func (n *noopLogger) Debug(msg string, args ...any) {}
|
||
|
|
func (n *noopLogger) Info(msg string, args ...any) {}
|
||
|
|
func (n *noopLogger) Warn(msg string, args ...any) {}
|
||
|
|
func (n *noopLogger) Error(msg string, err error, args ...any) {}
|
||
|
|
func (n *noopLogger) WithContext(ctx context.Context) Logger { return n }
|
||
|
|
|
||
|
|
func doRequest(t *testing.T, h http.Handler) (int, Response) {
|
||
|
|
t.Helper()
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, req)
|
||
|
|
|
||
|
|
var resp Response
|
||
|
|
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||
|
|
t.Fatalf("decode response: %v", err)
|
||
|
|
}
|
||
|
|
return rec.Code, resp
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- tests ---
|
||
|
|
|
||
|
|
func TestHandler_NoChecks(t *testing.T) {
|
||
|
|
h := NewHandler(&noopLogger{})
|
||
|
|
code, resp := doRequest(t, h)
|
||
|
|
if code != http.StatusOK {
|
||
|
|
t.Errorf("want 200, got %d", code)
|
||
|
|
}
|
||
|
|
if resp.Status != "UP" {
|
||
|
|
t.Errorf("want UP, got %s", resp.Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_AllUp(t *testing.T) {
|
||
|
|
h := NewHandler(&noopLogger{},
|
||
|
|
&mockCheck{name: "db", priority: LevelCritical},
|
||
|
|
&mockCheck{name: "cache", priority: LevelDegraded},
|
||
|
|
)
|
||
|
|
code, resp := doRequest(t, h)
|
||
|
|
if code != http.StatusOK {
|
||
|
|
t.Errorf("want 200, got %d", code)
|
||
|
|
}
|
||
|
|
if resp.Status != "UP" {
|
||
|
|
t.Errorf("want UP, got %s", resp.Status)
|
||
|
|
}
|
||
|
|
if resp.Components["db"].Status != "UP" {
|
||
|
|
t.Errorf("db: want UP, got %s", resp.Components["db"].Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_CriticalDown(t *testing.T) {
|
||
|
|
h := NewHandler(&noopLogger{},
|
||
|
|
&mockCheck{name: "db", priority: LevelCritical, err: errors.New("connection refused")},
|
||
|
|
)
|
||
|
|
code, resp := doRequest(t, h)
|
||
|
|
if code != http.StatusServiceUnavailable {
|
||
|
|
t.Errorf("want 503, got %d", code)
|
||
|
|
}
|
||
|
|
if resp.Status != "DOWN" {
|
||
|
|
t.Errorf("want DOWN, got %s", resp.Status)
|
||
|
|
}
|
||
|
|
if resp.Components["db"].Status != "DOWN" {
|
||
|
|
t.Errorf("db: want DOWN, got %s", resp.Components["db"].Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_DegradedDown(t *testing.T) {
|
||
|
|
h := NewHandler(&noopLogger{},
|
||
|
|
&mockCheck{name: "cache", priority: LevelDegraded, err: errors.New("timeout")},
|
||
|
|
)
|
||
|
|
code, resp := doRequest(t, h)
|
||
|
|
if code != http.StatusOK {
|
||
|
|
t.Errorf("want 200, got %d", code)
|
||
|
|
}
|
||
|
|
if resp.Status != "DEGRADED" {
|
||
|
|
t.Errorf("want DEGRADED, got %s", resp.Status)
|
||
|
|
}
|
||
|
|
if resp.Components["cache"].Status != "DEGRADED" {
|
||
|
|
t.Errorf("cache: want DEGRADED, got %s", resp.Components["cache"].Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_MixedDown(t *testing.T) {
|
||
|
|
h := NewHandler(&noopLogger{},
|
||
|
|
&mockCheck{name: "db", priority: LevelCritical, err: errors.New("down")},
|
||
|
|
&mockCheck{name: "cache", priority: LevelDegraded, err: errors.New("down")},
|
||
|
|
)
|
||
|
|
code, resp := doRequest(t, h)
|
||
|
|
if code != http.StatusServiceUnavailable {
|
||
|
|
t.Errorf("want 503, got %d", code)
|
||
|
|
}
|
||
|
|
if resp.Status != "DOWN" {
|
||
|
|
t.Errorf("want DOWN, got %s", resp.Status)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_ChecksParallel(t *testing.T) {
|
||
|
|
delay := 100 * time.Millisecond
|
||
|
|
h := NewHandler(&noopLogger{},
|
||
|
|
&mockCheck{name: "a", priority: LevelCritical, delay: delay},
|
||
|
|
&mockCheck{name: "b", priority: LevelCritical, delay: delay},
|
||
|
|
&mockCheck{name: "c", priority: LevelCritical, delay: delay},
|
||
|
|
)
|
||
|
|
start := time.Now()
|
||
|
|
doRequest(t, h)
|
||
|
|
elapsed := time.Since(start)
|
||
|
|
|
||
|
|
// parallel: should complete in ~delay, not 3*delay
|
||
|
|
if elapsed > 3*delay {
|
||
|
|
t.Errorf("checks do not appear to run in parallel: elapsed %v", elapsed)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_JSON_Shape(t *testing.T) {
|
||
|
|
h := NewHandler(&noopLogger{},
|
||
|
|
&mockCheck{name: "db", priority: LevelCritical},
|
||
|
|
)
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
h.ServeHTTP(rec, req)
|
||
|
|
|
||
|
|
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
|
||
|
|
t.Errorf("Content-Type: want application/json, got %s", ct)
|
||
|
|
}
|
||
|
|
|
||
|
|
var resp Response
|
||
|
|
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||
|
|
t.Fatalf("body is not valid JSON: %v", err)
|
||
|
|
}
|
||
|
|
if _, ok := resp.Components["db"]; !ok {
|
||
|
|
t.Error("components map missing 'db' key")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandler_ContextTimeout(t *testing.T) {
|
||
|
|
// Check that times out faster than the 5s global timeout when client cancels.
|
||
|
|
h := NewHandler(&noopLogger{},
|
||
|
|
&mockCheck{name: "slow", priority: LevelCritical, delay: 10 * time.Second},
|
||
|
|
)
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||
|
|
ctx, cancel := context.WithTimeout(req.Context(), 50*time.Millisecond)
|
||
|
|
defer cancel()
|
||
|
|
req = req.WithContext(ctx)
|
||
|
|
|
||
|
|
rec := httptest.NewRecorder()
|
||
|
|
start := time.Now()
|
||
|
|
h.ServeHTTP(rec, req)
|
||
|
|
elapsed := time.Since(start)
|
||
|
|
|
||
|
|
if elapsed > time.Second {
|
||
|
|
t.Errorf("handler did not respect context timeout: elapsed %v", elapsed)
|
||
|
|
}
|
||
|
|
}
|