131 lines
3.1 KiB
Go
131 lines
3.1 KiB
Go
|
|
package worker
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"code.nochebuena.dev/go/launcher"
|
||
|
|
"code.nochebuena.dev/go/logz"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Task is one unit of work that the pool runs on a background goroutine.
// The context it receives is derived from the pool's own context, which is
// canceled when the pool stops; when Config.TaskTimeout is positive it also
// carries that deadline. A non-nil error is logged by the pool rather than
// returned to whoever dispatched the task.
type Task func(context.Context) error

// Provider is the dispatch-only view of the pool: the single operation a
// producer needs in order to hand work off.
type Provider interface {
	// Dispatch tries to queue task and reports whether it was accepted.
	// A false result means the task was dropped (backpressure: the queue
	// is full) and will never run.
	Dispatch(task Task) bool
}
|
||
|
|
|
||
|
|
// Component adds lifecycle management to Provider.
|
||
|
|
type Component interface {
|
||
|
|
launcher.Component
|
||
|
|
Provider
|
||
|
|
}
|
||
|
|
|
||
|
|
// Config carries the tunables for a worker pool. Each field can be filled
// from the environment via its `env` tag, with `envDefault` supplying the
// value when the variable is unset.
type Config struct {
	// PoolSize sets how many worker goroutines run tasks concurrently.
	// Defaults to 5.
	PoolSize int `env:"WORKER_POOL_SIZE" envDefault:"5"`
	// BufferSize sets the capacity of the pending-task queue; once it is
	// full, Dispatch rejects new tasks. Defaults to 100.
	BufferSize int `env:"WORKER_BUFFER_SIZE" envDefault:"100"`
	// TaskTimeout caps how long a single task's context stays live.
	// Zero disables the per-task deadline.
	TaskTimeout time.Duration `env:"WORKER_TASK_TIMEOUT" envDefault:"0s"`
	// ShutdownTimeout bounds how long OnStop waits for the workers to
	// drain. Zero (unset) falls back to 30s.
	ShutdownTimeout time.Duration `env:"WORKER_SHUTDOWN_TIMEOUT" envDefault:"30s"`
}
|
||
|
|
|
||
|
|
type workerComponent struct {
|
||
|
|
logger logz.Logger
|
||
|
|
cfg Config
|
||
|
|
taskQueue chan Task
|
||
|
|
wg sync.WaitGroup
|
||
|
|
ctx context.Context
|
||
|
|
cancel context.CancelFunc
|
||
|
|
}
|
||
|
|
|
||
|
|
// New returns a worker Component. Call lc.Append(pool) to manage its lifecycle.
|
||
|
|
func New(logger logz.Logger, cfg Config) Component {
|
||
|
|
if cfg.PoolSize <= 0 {
|
||
|
|
cfg.PoolSize = 5
|
||
|
|
}
|
||
|
|
if cfg.BufferSize <= 0 {
|
||
|
|
cfg.BufferSize = 100
|
||
|
|
}
|
||
|
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
|
return &workerComponent{
|
||
|
|
logger: logger,
|
||
|
|
cfg: cfg,
|
||
|
|
taskQueue: make(chan Task, cfg.BufferSize),
|
||
|
|
ctx: ctx,
|
||
|
|
cancel: cancel,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (w *workerComponent) OnInit() error {
|
||
|
|
w.logger.Info("worker: initializing pool",
|
||
|
|
"pool_size", w.cfg.PoolSize,
|
||
|
|
"buffer_size", w.cfg.BufferSize)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (w *workerComponent) OnStart() error {
|
||
|
|
w.logger.Info("worker: starting workers")
|
||
|
|
for i := 0; i < w.cfg.PoolSize; i++ {
|
||
|
|
w.wg.Add(1)
|
||
|
|
go func(id int) {
|
||
|
|
defer w.wg.Done()
|
||
|
|
w.runWorker(id)
|
||
|
|
}(i)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (w *workerComponent) OnStop() error {
|
||
|
|
w.logger.Info("worker: stopping, draining queue")
|
||
|
|
close(w.taskQueue)
|
||
|
|
w.cancel()
|
||
|
|
|
||
|
|
done := make(chan struct{})
|
||
|
|
go func() { w.wg.Wait(); close(done) }()
|
||
|
|
|
||
|
|
timeout := w.cfg.ShutdownTimeout
|
||
|
|
if timeout == 0 {
|
||
|
|
timeout = 30 * time.Second
|
||
|
|
}
|
||
|
|
select {
|
||
|
|
case <-done:
|
||
|
|
w.logger.Info("worker: all workers stopped cleanly")
|
||
|
|
case <-time.After(timeout):
|
||
|
|
w.logger.Error("worker: shutdown timeout reached; some workers may still be running", nil)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (w *workerComponent) Dispatch(task Task) bool {
|
||
|
|
select {
|
||
|
|
case w.taskQueue <- task:
|
||
|
|
return true
|
||
|
|
default:
|
||
|
|
w.logger.Error("worker: queue full, task dropped", nil)
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (w *workerComponent) runWorker(id int) {
|
||
|
|
for task := range w.taskQueue {
|
||
|
|
ctx := w.ctx
|
||
|
|
var cancel context.CancelFunc
|
||
|
|
if w.cfg.TaskTimeout > 0 {
|
||
|
|
ctx, cancel = context.WithTimeout(ctx, w.cfg.TaskTimeout)
|
||
|
|
} else {
|
||
|
|
cancel = func() {}
|
||
|
|
}
|
||
|
|
if err := task(ctx); err != nil {
|
||
|
|
w.logger.Error("worker: task failed", err, "worker_id", id)
|
||
|
|
}
|
||
|
|
cancel()
|
||
|
|
}
|
||
|
|
}
|