package health import ( "context" "encoding/json" "errors" "net/http" "net/http/httptest" "testing" "time" ) // --- mock helpers --- type mockCheck struct { name string priority Level err error delay time.Duration } func (m *mockCheck) HealthCheck(ctx context.Context) error { if m.delay > 0 { select { case <-time.After(m.delay): case <-ctx.Done(): return ctx.Err() } } return m.err } func (m *mockCheck) Name() string { return m.name } func (m *mockCheck) Priority() Level { return m.priority } type noopLogger struct{} func (n *noopLogger) Debug(msg string, args ...any) {} func (n *noopLogger) Info(msg string, args ...any) {} func (n *noopLogger) Warn(msg string, args ...any) {} func (n *noopLogger) Error(msg string, err error, args ...any) {} func (n *noopLogger) WithContext(ctx context.Context) Logger { return n } func doRequest(t *testing.T, h http.Handler) (int, Response) { t.Helper() req := httptest.NewRequest(http.MethodGet, "/health", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) var resp Response if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { t.Fatalf("decode response: %v", err) } return rec.Code, resp } // --- tests --- func TestHandler_NoChecks(t *testing.T) { h := NewHandler(&noopLogger{}) code, resp := doRequest(t, h) if code != http.StatusOK { t.Errorf("want 200, got %d", code) } if resp.Status != "UP" { t.Errorf("want UP, got %s", resp.Status) } } func TestHandler_AllUp(t *testing.T) { h := NewHandler(&noopLogger{}, &mockCheck{name: "db", priority: LevelCritical}, &mockCheck{name: "cache", priority: LevelDegraded}, ) code, resp := doRequest(t, h) if code != http.StatusOK { t.Errorf("want 200, got %d", code) } if resp.Status != "UP" { t.Errorf("want UP, got %s", resp.Status) } if resp.Components["db"].Status != "UP" { t.Errorf("db: want UP, got %s", resp.Components["db"].Status) } } func TestHandler_CriticalDown(t *testing.T) { h := NewHandler(&noopLogger{}, &mockCheck{name: "db", priority: LevelCritical, err: errors.New("connection refused")}, ) code, resp := doRequest(t, h) if code != http.StatusServiceUnavailable { t.Errorf("want 503, got %d", code) } if resp.Status != "DOWN" { t.Errorf("want DOWN, got %s", resp.Status) } if resp.Components["db"].Status != "DOWN" { t.Errorf("db: want DOWN, got %s", resp.Components["db"].Status) } } func TestHandler_DegradedDown(t *testing.T) { h := NewHandler(&noopLogger{}, &mockCheck{name: "cache", priority: LevelDegraded, err: errors.New("timeout")}, ) code, resp := doRequest(t, h) if code != http.StatusOK { t.Errorf("want 200, got %d", code) } if resp.Status != "DEGRADED" { t.Errorf("want DEGRADED, got %s", resp.Status) } if resp.Components["cache"].Status != "DEGRADED" { t.Errorf("cache: want DEGRADED, got %s", resp.Components["cache"].Status) } } func TestHandler_MixedDown(t *testing.T) { h := NewHandler(&noopLogger{}, &mockCheck{name: "db", priority: LevelCritical, err: errors.New("down")}, &mockCheck{name: "cache", priority: LevelDegraded, err: errors.New("down")}, ) code, resp := doRequest(t, h) if code != http.StatusServiceUnavailable { t.Errorf("want 503, got %d", code) } if resp.Status != "DOWN" { t.Errorf("want DOWN, got %s", resp.Status) } } func TestHandler_ChecksParallel(t *testing.T) { delay := 100 * time.Millisecond h := NewHandler(&noopLogger{}, &mockCheck{name: "a", priority: LevelCritical, delay: delay}, &mockCheck{name: "b", priority: LevelCritical, delay: delay}, &mockCheck{name: "c", priority: LevelCritical, delay: delay}, ) start := time.Now() doRequest(t, h) elapsed := time.Since(start) // parallel: should complete in ~delay, not 3*delay if elapsed > 3*delay { t.Errorf("checks do not appear to run in parallel: elapsed %v", elapsed) } } func TestHandler_JSON_Shape(t *testing.T) { h := NewHandler(&noopLogger{}, &mockCheck{name: "db", priority: LevelCritical}, ) req := httptest.NewRequest(http.MethodGet, "/health", nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, req) if ct := rec.Header().Get("Content-Type"); ct != "application/json" { t.Errorf("Content-Type: want application/json, got %s", ct) } var resp Response if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { t.Fatalf("body is not valid JSON: %v", err) } if _, ok := resp.Components["db"]; !ok { t.Error("components map missing 'db' key") } } func TestHandler_ContextTimeout(t *testing.T) { // Check that times out faster than the 5s global timeout when client cancels. h := NewHandler(&noopLogger{}, &mockCheck{name: "slow", priority: LevelCritical, delay: 10 * time.Second}, ) req := httptest.NewRequest(http.MethodGet, "/health", nil) ctx, cancel := context.WithTimeout(req.Context(), 50*time.Millisecond) defer cancel() req = req.WithContext(ctx) rec := httptest.NewRecorder() start := time.Now() h.ServeHTTP(rec, req) elapsed := time.Since(start) if elapsed > time.Second { t.Errorf("handler did not respect context timeout: elapsed %v", elapsed) } }