package worker

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"code.nochebuena.dev/go/logz"
)

// newLogger returns a logger built from zero-value logz options for tests.
func newLogger() logz.Logger { return logz.New(logz.Options{}) }

// startWorker builds a worker from cfg and runs OnInit and OnStart, aborting
// the test if either returns an error. The caller is responsible for stopping
// the returned component.
func startWorker(t *testing.T, cfg Config) Component {
	t.Helper()
	c := New(newLogger(), cfg)
	if err := c.OnInit(); err != nil {
		t.Fatalf("OnInit: %v", err)
	}
	if err := c.OnStart(); err != nil {
		t.Fatalf("OnStart: %v", err)
	}
	return c
}

// stopWorker calls OnStop and reports, without aborting the test, any error it
// returns. Use it where shutdown is expected to be clean (TestWorker_Lifecycle
// asserts that OnStop returns nil in that case).
func stopWorker(t *testing.T, c Component) {
	t.Helper()
	if err := c.OnStop(); err != nil {
		t.Errorf("OnStop: %v", err)
	}
}

// TestNew checks that the constructor returns a non-nil component.
func TestNew(t *testing.T) {
	if New(newLogger(), Config{}) == nil {
		t.Fatal("New returned nil")
	}
}

// TestWorker_DispatchAndExecute checks that a dispatched task gets executed.
func TestWorker_DispatchAndExecute(t *testing.T) {
	c := startWorker(t, Config{PoolSize: 1, BufferSize: 10, ShutdownTimeout: time.Second})
	defer stopWorker(t, c)

	done := make(chan struct{})
	task := func(ctx context.Context) error {
		close(done)
		return nil
	}
	if !c.Dispatch(task) {
		t.Fatal("Dispatch rejected the task")
	}

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("task not executed in time")
	}
}

// TestWorker_BackpressureFull checks that Dispatch returns false once the
// single worker is busy and the one-slot buffer is occupied.
func TestWorker_BackpressureFull(t *testing.T) {
	c := startWorker(t, Config{PoolSize: 1, BufferSize: 1, ShutdownTimeout: time.Second})
	defer stopWorker(t, c)

	// block parks the worker. Deferred after stopWorker, so it runs first
	// (LIFO) and lets the pool drain before OnStop.
	block := make(chan struct{})
	defer close(block)

	running := make(chan struct{})
	blocking := func(ctx context.Context) error {
		close(running)
		<-block
		return nil
	}
	if !c.Dispatch(blocking) {
		t.Fatal("Dispatch rejected the blocking task")
	}

	// Wait until the worker has actually dequeued the blocking task. Without
	// this, that task may still sit in the buffer, and the worker could free
	// the slot between the next two Dispatch calls, making the test flaky.
	select {
	case <-running:
	case <-time.After(time.Second):
		t.Fatal("blocking task did not start")
	}

	if !c.Dispatch(func(ctx context.Context) error { return nil }) {
		t.Fatal("Dispatch rejected a task while the buffer had room")
	}
	if c.Dispatch(func(ctx context.Context) error { return nil }) {
		t.Error("expected Dispatch to return false when queue is full")
	}
}

// TestWorker_OnStop_DrainsQueue checks that every accepted task has run by the
// time OnStop returns.
func TestWorker_OnStop_DrainsQueue(t *testing.T) {
	const tasks = 10
	var count int64
	c := startWorker(t, Config{PoolSize: 2, BufferSize: 50, ShutdownTimeout: 5 * time.Second})

	task := func(ctx context.Context) error {
		atomic.AddInt64(&count, 1)
		return nil
	}
	for i := 0; i < tasks; i++ {
		if !c.Dispatch(task) {
			t.Errorf("Dispatch %d was rejected", i)
		}
	}

	stopWorker(t, c)

	// Load once and report that value. Reading count directly in the message
	// would race with any task still running if shutdown timed out.
	if got := atomic.LoadInt64(&count); got != tasks {
		t.Errorf("expected %d tasks completed, got %d", tasks, got)
	}
}

// TestWorker_OnStop_Timeout checks that OnStop returns after roughly
// ShutdownTimeout instead of waiting for an in-flight task that ignores
// cancellation.
func TestWorker_OnStop_Timeout(t *testing.T) {
	c := startWorker(t, Config{PoolSize: 1, BufferSize: 1, ShutdownTimeout: 50 * time.Millisecond})

	running := make(chan struct{})
	// Closing release when the test returns lets the stuck task finish, so its
	// goroutine does not keep running into later tests.
	release := make(chan struct{})
	defer close(release)

	// The task deliberately ignores ctx. The 500ms cap bounds the test if
	// OnStop wrongly waits for the task instead of timing out.
	task := func(ctx context.Context) error {
		close(running)
		select {
		case <-release:
		case <-time.After(500 * time.Millisecond):
		}
		return nil
	}
	if !c.Dispatch(task) {
		t.Fatal("Dispatch rejected the task")
	}

	// Ensure the task is in flight, not merely queued, before stopping.
	select {
	case <-running:
	case <-time.After(time.Second):
		t.Fatal("task did not start")
	}

	start := time.Now()
	// Only the duration is asserted here. Whether OnStop reports the timeout
	// as an error is not part of this test.
	_ = c.OnStop()
	elapsed := time.Since(start)

	// OnStop should return after ~ShutdownTimeout, not after 500ms.
	if elapsed > 300*time.Millisecond {
		t.Errorf("OnStop blocked too long: %v", elapsed)
	}
}

// TestWorker_TaskTimeout checks that a task's context is cancelled once
// TaskTimeout elapses.
func TestWorker_TaskTimeout(t *testing.T) {
	var ctxCancelled int64
	c := startWorker(t, Config{
		PoolSize:        1,
		BufferSize:      10,
		TaskTimeout:     50 * time.Millisecond,
		ShutdownTimeout: time.Second,
	})
	defer stopWorker(t, c)

	done := make(chan struct{})
	task := func(ctx context.Context) error {
		defer close(done)
		select {
		case <-ctx.Done():
			atomic.StoreInt64(&ctxCancelled, 1)
		case <-time.After(500 * time.Millisecond):
		}
		return nil
	}
	if !c.Dispatch(task) {
		t.Fatal("Dispatch rejected the task")
	}

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("task did not complete in time")
	}

	if atomic.LoadInt64(&ctxCancelled) != 1 {
		t.Error("expected task context to be cancelled by TaskTimeout")
	}
}

// TestWorker_MultipleWorkers checks that PoolSize workers run tasks
// concurrently. Every task blocks until released, so n tasks starting implies
// n workers are busy at once.
func TestWorker_MultipleWorkers(t *testing.T) {
	const n = 5
	c := startWorker(t, Config{PoolSize: n, BufferSize: n, ShutdownTimeout: time.Second})

	var wg sync.WaitGroup
	wg.Add(n)
	started := make(chan struct{}, n)
	block := make(chan struct{})
	task := func(ctx context.Context) error {
		defer wg.Done()
		started <- struct{}{}
		<-block
		return nil
	}
	for i := 0; i < n; i++ {
		if !c.Dispatch(task) {
			t.Errorf("Dispatch %d was rejected", i)
		}
	}

	// All n tasks should start concurrently.
	timer := time.NewTimer(time.Second)
	defer timer.Stop()
	for i := 0; i < n; i++ {
		select {
		case <-started:
		case <-timer.C:
			// Release the parked tasks and stop the pool before aborting so
			// their goroutines do not outlive the test.
			close(block)
			_ = c.OnStop()
			t.Fatalf("only %d/%d workers started concurrently", i, n)
		}
	}

	close(block)
	// All n tasks are running and now released, so this returns promptly and
	// confirms each one ran to completion.
	wg.Wait()
	stopWorker(t, c)
}

// TestWorker_TaskError checks that a task returning an error neither stalls
// the pool nor breaks shutdown.
func TestWorker_TaskError(t *testing.T) {
	c := startWorker(t, Config{PoolSize: 1, BufferSize: 10, ShutdownTimeout: time.Second})
	// Stopping after a failed task should succeed (and not panic).
	defer stopWorker(t, c)

	done := make(chan struct{})
	task := func(ctx context.Context) error {
		defer close(done)
		return errors.New("task error")
	}
	if !c.Dispatch(task) {
		t.Fatal("Dispatch rejected the task")
	}

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("task did not run")
	}
}

// TestWorker_Lifecycle drives the full OnInit -> OnStart -> Dispatch -> OnStop
// sequence by hand and requires every step to succeed.
func TestWorker_Lifecycle(t *testing.T) {
	c := New(newLogger(), Config{PoolSize: 2, BufferSize: 10, ShutdownTimeout: time.Second})
	if err := c.OnInit(); err != nil {
		t.Fatalf("OnInit: %v", err)
	}
	if err := c.OnStart(); err != nil {
		t.Fatalf("OnStart: %v", err)
	}

	done := make(chan struct{})
	task := func(ctx context.Context) error {
		close(done)
		return nil
	}
	if !c.Dispatch(task) {
		t.Fatal("Dispatch rejected the task")
	}

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("task not executed")
	}

	if err := c.OnStop(); err != nil {
		t.Fatalf("OnStop: %v", err)
	}
}