Files
httpserver/httpserver_test.go

193 lines
4.5 KiB
Go
Raw Permalink Normal View History

package httpserver
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
)
// --- helpers ---
// testLogger is a logger double that keeps only the most recent message,
// prefixed with its level ("info:" or "error:"). The error value and any
// variadic arguments are discarded.
type testLogger struct {
	last string
}

// Info records msg as "info:<msg>", overwriting the previous entry.
func (l *testLogger) Info(msg string, _ ...any) {
	l.last = "info:" + msg
}

// Error records msg as "error:<msg>", overwriting the previous entry.
func (l *testLogger) Error(msg string, _ error, _ ...any) {
	l.last = "error:" + msg
}

// newLogger returns an empty testLogger.
func newLogger() *testLogger {
	return new(testLogger)
}
// freePort asks the OS for an unused loopback TCP port by listening on
// 127.0.0.1:0, then releases the listener and returns the chosen port.
// The port is not reserved: another process could bind it before the caller does.
func freePort(t *testing.T) int {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	return ln.Addr().(*net.TCPAddr).Port
}
// --- Tests ---
// TestNew checks that the constructor yields a non-nil server for a zero Config.
func TestNew(t *testing.T) {
	if New(newLogger(), Config{}) == nil {
		t.Fatal("New returned nil")
	}
}
// TestNew_ImplementsLauncherComponent exercises the launcher lifecycle entry
// point at runtime. The full launcher.Component method set is checked at
// compile time in compliance_test.go.
func TestNew_ImplementsLauncherComponent(t *testing.T) {
	srv := New(newLogger(), Config{})
	err := srv.OnInit()
	if err != nil {
		t.Fatalf("OnInit: %v", err)
	}
}
// TestNew_WithMiddleware checks that middleware supplied via WithMiddleware
// wraps registered routes. For a matching request, the middleware must run
// and must hand the request on to the route handler.
func TestNew_WithMiddleware(t *testing.T) {
	middlewareCalled := false
	handlerCalled := false
	mw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			middlewareCalled = true
			next.ServeHTTP(w, r)
		})
	}
	srv := New(newLogger(), Config{}, WithMiddleware(mw))
	if err := srv.OnInit(); err != nil {
		t.Fatalf("OnInit: %v", err)
	}
	srv.Get("/test", func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rec := httptest.NewRecorder()
	// ServeHTTP runs synchronously, so the flags are settled once it returns.
	srv.ServeHTTP(rec, req)

	if !middlewareCalled {
		t.Error("WithMiddleware: middleware was not called")
	}
	// Without this check, a middleware that swallowed the request (never
	// calling next) would still pass the test.
	if !handlerCalled {
		t.Error("WithMiddleware: route handler was not reached through the middleware")
	}
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
}
// TestComponent_OnInit checks that initialising a zero-Config server succeeds.
func TestComponent_OnInit(t *testing.T) {
	err := New(newLogger(), Config{}).OnInit()
	if err != nil {
		t.Errorf("OnInit returned error: %v", err)
	}
}
// TestComponent_Routes registers a GET route after OnInit and checks that a
// request to it is dispatched to the handler (200 rather than a router miss).
func TestComponent_Routes(t *testing.T) {
	srv := New(newLogger(), Config{})
	if err := srv.OnInit(); err != nil {
		t.Fatal(err)
	}
	okHandler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	srv.Get("/ping", okHandler)

	rec := httptest.NewRecorder()
	srv.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/ping", nil))
	if got, want := rec.Code, http.StatusOK; got != want {
		t.Errorf("want 200, got %d", got)
	}
}
// TestComponent_RouteGroup mounts a sub-router under /api and checks that a
// nested route is reachable at the combined path /api/v1/hello.
func TestComponent_RouteGroup(t *testing.T) {
	srv := New(newLogger(), Config{})
	if err := srv.OnInit(); err != nil {
		t.Fatal(err)
	}
	// The router parameter is named api so the handler's request parameter
	// does not shadow it.
	srv.Route("/api", func(api chi.Router) {
		api.Get("/v1/hello", func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusOK)
		})
	})

	rec := httptest.NewRecorder()
	srv.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/hello", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("want 200, got %d", rec.Code)
	}
}
// TestComponent_OnStart_OnStop starts the server on a free loopback port,
// waits until it actually accepts TCP connections, and then stops it.
func TestComponent_OnStart_OnStop(t *testing.T) {
	port := freePort(t)
	srv := New(newLogger(), Config{
		Host:         "127.0.0.1",
		Port:         port,
		ReadTimeout:  time.Second,
		WriteTimeout: time.Second,
		IdleTimeout:  time.Second,
	})
	if err := srv.OnInit(); err != nil {
		t.Fatal(err)
	}
	if err := srv.OnStart(); err != nil {
		t.Fatalf("OnStart: %v", err)
	}

	// OnStart binds from a background goroutine. Poll until the listener is
	// up rather than sleeping a fixed interval: a fixed sleep is racy on a
	// loaded machine and never proves the server was listening.
	addr := net.JoinHostPort("127.0.0.1", itoa(port))
	deadline := time.Now().Add(5 * time.Second)
	for {
		conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
		if err == nil {
			conn.Close()
			break
		}
		if time.Now().After(deadline) {
			t.Fatalf("server never started listening on %s: %v", addr, err)
		}
		time.Sleep(5 * time.Millisecond)
	}

	if err := srv.OnStop(); err != nil {
		t.Errorf("OnStop: %v", err)
	}
}
func TestComponent_OnStop_Graceful(t *testing.T) {
port := freePort(t)
srv := New(newLogger(), Config{
Host: "127.0.0.1",
Port: port,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: 5 * time.Second,
})
if err := srv.OnInit(); err != nil {
t.Fatal(err)
}
// Register a slow handler.
done := make(chan struct{})
srv.Get("/slow", func(w http.ResponseWriter, r *http.Request) {
<-done
w.WriteHeader(http.StatusOK)
})
if err := srv.OnStart(); err != nil {
t.Fatal(err)
}
time.Sleep(20 * time.Millisecond)
// Fire a request in the background.
result := make(chan int, 1)
go func() {
resp, err := http.Get("http://127.0.0.1:" + itoa(port) + "/slow")
if err != nil {
result <- 0
return
}
result <- resp.StatusCode
}()
time.Sleep(20 * time.Millisecond)
// Unblock handler, then stop the server.
close(done)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
h := srv.(*httpServer)
if err := h.srv.Shutdown(ctx); err != nil {
t.Errorf("graceful shutdown error: %v", err)
}
if code := <-result; code != http.StatusOK {
t.Errorf("in-flight request: want 200, got %d", code)
}
}
// itoa formats n as a base-10 string.
func itoa(n int) string {
	return fmt.Sprintf("%d", n)
}