202 lines
4.6 KiB
Go
202 lines
4.6 KiB
Go
|
|
package worker
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"sync"
|
||
|
|
"sync/atomic"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"code.nochebuena.dev/go/logz"
|
||
|
|
)
|
||
|
|
|
||
|
|
// newLogger builds a logz.Logger from zero-value logz.Options for use by the
// tests in this file.
func newLogger() logz.Logger {
	opts := logz.Options{}
	return logz.New(opts)
}
|
||
|
|
|
||
|
|
func startWorker(t *testing.T, cfg Config) Component {
|
||
|
|
t.Helper()
|
||
|
|
c := New(newLogger(), cfg)
|
||
|
|
if err := c.OnInit(); err != nil {
|
||
|
|
t.Fatalf("OnInit: %v", err)
|
||
|
|
}
|
||
|
|
if err := c.OnStart(); err != nil {
|
||
|
|
t.Fatalf("OnStart: %v", err)
|
||
|
|
}
|
||
|
|
return c
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestNew(t *testing.T) {
|
||
|
|
if New(newLogger(), Config{}) == nil {
|
||
|
|
t.Fatal("New returned nil")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_DispatchAndExecute(t *testing.T) {
|
||
|
|
done := make(chan struct{})
|
||
|
|
c := startWorker(t, Config{PoolSize: 1, BufferSize: 10, ShutdownTimeout: time.Second})
|
||
|
|
|
||
|
|
c.Dispatch(func(ctx context.Context) error {
|
||
|
|
close(done)
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
select {
|
||
|
|
case <-done:
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("task not executed in time")
|
||
|
|
}
|
||
|
|
c.OnStop()
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_BackpressureFull(t *testing.T) {
|
||
|
|
// Block the single worker so the queue fills up.
|
||
|
|
block := make(chan struct{})
|
||
|
|
c := startWorker(t, Config{PoolSize: 1, BufferSize: 1, ShutdownTimeout: time.Second})
|
||
|
|
|
||
|
|
c.Dispatch(func(ctx context.Context) error { <-block; return nil }) // fills worker
|
||
|
|
c.Dispatch(func(ctx context.Context) error { return nil }) // fills buffer
|
||
|
|
|
||
|
|
ok := c.Dispatch(func(ctx context.Context) error { return nil }) // should be dropped
|
||
|
|
if ok {
|
||
|
|
t.Error("expected Dispatch to return false when queue is full")
|
||
|
|
}
|
||
|
|
close(block)
|
||
|
|
c.OnStop()
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_OnStop_DrainsQueue(t *testing.T) {
|
||
|
|
var count int64
|
||
|
|
c := startWorker(t, Config{PoolSize: 2, BufferSize: 50, ShutdownTimeout: 5 * time.Second})
|
||
|
|
|
||
|
|
for i := 0; i < 10; i++ {
|
||
|
|
c.Dispatch(func(ctx context.Context) error {
|
||
|
|
atomic.AddInt64(&count, 1)
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
}
|
||
|
|
c.OnStop()
|
||
|
|
|
||
|
|
if atomic.LoadInt64(&count) != 10 {
|
||
|
|
t.Errorf("expected 10 tasks completed, got %d", count)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_OnStop_Timeout(t *testing.T) {
|
||
|
|
c := startWorker(t, Config{PoolSize: 1, BufferSize: 1,
|
||
|
|
ShutdownTimeout: 50 * time.Millisecond})
|
||
|
|
|
||
|
|
// Dispatch a task that blocks longer than ShutdownTimeout.
|
||
|
|
c.Dispatch(func(ctx context.Context) error {
|
||
|
|
time.Sleep(500 * time.Millisecond)
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
start := time.Now()
|
||
|
|
c.OnStop()
|
||
|
|
elapsed := time.Since(start)
|
||
|
|
|
||
|
|
// OnStop should return after ~ShutdownTimeout, not after 500ms.
|
||
|
|
if elapsed > 300*time.Millisecond {
|
||
|
|
t.Errorf("OnStop blocked too long: %v", elapsed)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_TaskTimeout(t *testing.T) {
|
||
|
|
var ctxCancelled int64
|
||
|
|
c := startWorker(t, Config{
|
||
|
|
PoolSize: 1, BufferSize: 10,
|
||
|
|
TaskTimeout: 50 * time.Millisecond,
|
||
|
|
ShutdownTimeout: time.Second,
|
||
|
|
})
|
||
|
|
|
||
|
|
done := make(chan struct{})
|
||
|
|
c.Dispatch(func(ctx context.Context) error {
|
||
|
|
defer close(done)
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
atomic.StoreInt64(&ctxCancelled, 1)
|
||
|
|
case <-time.After(500 * time.Millisecond):
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
|
||
|
|
select {
|
||
|
|
case <-done:
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("task did not complete in time")
|
||
|
|
}
|
||
|
|
if atomic.LoadInt64(&ctxCancelled) != 1 {
|
||
|
|
t.Error("expected task context to be cancelled by TaskTimeout")
|
||
|
|
}
|
||
|
|
c.OnStop()
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_MultipleWorkers(t *testing.T) {
|
||
|
|
const n = 5
|
||
|
|
var wg sync.WaitGroup
|
||
|
|
wg.Add(n)
|
||
|
|
started := make(chan struct{}, n)
|
||
|
|
|
||
|
|
c := startWorker(t, Config{PoolSize: n, BufferSize: n, ShutdownTimeout: time.Second})
|
||
|
|
|
||
|
|
block := make(chan struct{})
|
||
|
|
for i := 0; i < n; i++ {
|
||
|
|
c.Dispatch(func(ctx context.Context) error {
|
||
|
|
started <- struct{}{}
|
||
|
|
<-block
|
||
|
|
wg.Done()
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// All n tasks should start concurrently.
|
||
|
|
timer := time.NewTimer(time.Second)
|
||
|
|
defer timer.Stop()
|
||
|
|
for i := 0; i < n; i++ {
|
||
|
|
select {
|
||
|
|
case <-started:
|
||
|
|
case <-timer.C:
|
||
|
|
t.Fatalf("only %d/%d workers started concurrently", i, n)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
close(block)
|
||
|
|
c.OnStop()
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_TaskError(t *testing.T) {
|
||
|
|
c := startWorker(t, Config{PoolSize: 1, BufferSize: 10, ShutdownTimeout: time.Second})
|
||
|
|
|
||
|
|
done := make(chan struct{})
|
||
|
|
c.Dispatch(func(ctx context.Context) error {
|
||
|
|
defer close(done)
|
||
|
|
return errors.New("task error")
|
||
|
|
})
|
||
|
|
|
||
|
|
select {
|
||
|
|
case <-done:
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("task did not run")
|
||
|
|
}
|
||
|
|
c.OnStop() // should not panic
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWorker_Lifecycle(t *testing.T) {
|
||
|
|
c := New(newLogger(), Config{PoolSize: 2, BufferSize: 10, ShutdownTimeout: time.Second})
|
||
|
|
if err := c.OnInit(); err != nil {
|
||
|
|
t.Fatalf("OnInit: %v", err)
|
||
|
|
}
|
||
|
|
if err := c.OnStart(); err != nil {
|
||
|
|
t.Fatalf("OnStart: %v", err)
|
||
|
|
}
|
||
|
|
done := make(chan struct{})
|
||
|
|
c.Dispatch(func(ctx context.Context) error { close(done); return nil })
|
||
|
|
select {
|
||
|
|
case <-done:
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("task not executed")
|
||
|
|
}
|
||
|
|
if err := c.OnStop(); err != nil {
|
||
|
|
t.Fatalf("OnStop: %v", err)
|
||
|
|
}
|
||
|
|
}
|