package httpserver import ( "context" "fmt" "net" "net/http" "net/http/httptest" "testing" "time" "github.com/go-chi/chi/v5" ) // --- helpers --- type testLogger struct{ last string } func (l *testLogger) Info(msg string, args ...any) { l.last = "info:" + msg } func (l *testLogger) Error(msg string, err error, args ...any) { l.last = "error:" + msg } func newLogger() *testLogger { return &testLogger{} } // freePort returns a random available TCP port. func freePort(t *testing.T) int { t.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } port := ln.Addr().(*net.TCPAddr).Port ln.Close() return port } // --- Tests --- func TestNew(t *testing.T) { srv := New(newLogger(), Config{}) if srv == nil { t.Fatal("New returned nil") } } func TestNew_ImplementsLauncherComponent(t *testing.T) { srv := New(newLogger(), Config{}) // Verify launcher.Component methods exist (compile-time checked in compliance_test.go) if err := srv.OnInit(); err != nil { t.Fatalf("OnInit: %v", err) } } func TestNew_WithMiddleware(t *testing.T) { called := false mw := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true next.ServeHTTP(w, r) }) } srv := New(newLogger(), Config{}, WithMiddleware(mw)) if err := srv.OnInit(); err != nil { t.Fatalf("OnInit: %v", err) } srv.Get("/test", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/test", nil) rec := httptest.NewRecorder() srv.ServeHTTP(rec, req) if !called { t.Error("WithMiddleware: middleware was not called") } } func TestComponent_OnInit(t *testing.T) { srv := New(newLogger(), Config{}) if err := srv.OnInit(); err != nil { t.Errorf("OnInit returned error: %v", err) } } func TestComponent_Routes(t *testing.T) { srv := New(newLogger(), Config{}) if err := srv.OnInit(); err != nil { t.Fatal(err) } srv.Get("/ping", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/ping", nil) rec := httptest.NewRecorder() srv.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } } func TestComponent_RouteGroup(t *testing.T) { srv := New(newLogger(), Config{}) if err := srv.OnInit(); err != nil { t.Fatal(err) } srv.Route("/api", func(r chi.Router) { r.Get("/v1/hello", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) }) req := httptest.NewRequest(http.MethodGet, "/api/v1/hello", nil) rec := httptest.NewRecorder() srv.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("want 200, got %d", rec.Code) } } func TestComponent_OnStart_OnStop(t *testing.T) { port := freePort(t) srv := New(newLogger(), Config{ Host: "127.0.0.1", Port: port, ReadTimeout: time.Second, WriteTimeout: time.Second, IdleTimeout: time.Second, }) if err := srv.OnInit(); err != nil { t.Fatal(err) } if err := srv.OnStart(); err != nil { t.Fatalf("OnStart: %v", err) } // Give the goroutine time to bind. time.Sleep(20 * time.Millisecond) if err := srv.OnStop(); err != nil { t.Errorf("OnStop: %v", err) } } func TestComponent_OnStop_Graceful(t *testing.T) { port := freePort(t) srv := New(newLogger(), Config{ Host: "127.0.0.1", Port: port, ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, IdleTimeout: 5 * time.Second, }) if err := srv.OnInit(); err != nil { t.Fatal(err) } // Register a slow handler. done := make(chan struct{}) srv.Get("/slow", func(w http.ResponseWriter, r *http.Request) { <-done w.WriteHeader(http.StatusOK) }) if err := srv.OnStart(); err != nil { t.Fatal(err) } time.Sleep(20 * time.Millisecond) // Fire a request in the background. result := make(chan int, 1) go func() { resp, err := http.Get("http://127.0.0.1:" + itoa(port) + "/slow") if err != nil { result <- 0 return } result <- resp.StatusCode }() time.Sleep(20 * time.Millisecond) // Unblock handler, then stop the server. close(done) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() h := srv.(*httpServer) if err := h.srv.Shutdown(ctx); err != nil { t.Errorf("graceful shutdown error: %v", err) } if code := <-result; code != http.StatusOK { t.Errorf("in-flight request: want 200, got %d", code) } } func itoa(n int) string { return fmt.Sprint(n) }