package postgres

import (
	"context"
	"fmt"
	"net"
	"net/url"
	"strconv"
	"sync"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"

	"code.nochebuena.dev/go/health"
	"code.nochebuena.dev/go/launcher"
	"code.nochebuena.dev/go/logz"
)

// Executor is the shared query interface for pool and transaction.
type Executor interface {
	Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
	Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}

// Tx extends Executor with commit/rollback.
type Tx interface {
	Executor
	Commit(ctx context.Context) error
	Rollback(ctx context.Context) error
}

// Client is the database access interface.
type Client interface {
	// GetExecutor returns the active transaction from ctx if one exists,
	// otherwise returns the pool.
	GetExecutor(ctx context.Context) Executor
	Begin(ctx context.Context) (Tx, error)
	Ping(ctx context.Context) error
	HandleError(err error) error
}

// Component bundles launcher lifecycle, health check, and database client.
type Component interface {
	launcher.Component
	health.Checkable
	Client
}

// UnitOfWork wraps operations in a single database transaction.
type UnitOfWork interface {
	Do(ctx context.Context, fn func(ctx context.Context) error) error
}

// Config holds PostgreSQL connection settings.
type Config struct { Host string `env:"PG_HOST,required"` Port int `env:"PG_PORT" envDefault:"5432"` User string `env:"PG_USER,required"` Password string `env:"PG_PASSWORD,required"` Name string `env:"PG_NAME,required"` SSLMode string `env:"PG_SSL_MODE" envDefault:"disable"` Timezone string `env:"PG_TIMEZONE" envDefault:"UTC"` MaxConns int `env:"PG_MAX_CONNS" envDefault:"5"` MinConns int `env:"PG_MIN_CONNS" envDefault:"2"` MaxConnLifetime string `env:"PG_MAX_CONN_LIFETIME" envDefault:"1h"` MaxConnIdleTime string `env:"PG_MAX_CONN_IDLE_TIME" envDefault:"30m"` HealthCheckPeriod string `env:"PG_HEALTH_CHECK_PERIOD" envDefault:"1m"` } // DSN constructs a PostgreSQL connection string from the configuration. func (c Config) DSN() string { u := &url.URL{ Scheme: "postgres", User: url.UserPassword(c.User, c.Password), Host: fmt.Sprintf("%s:%d", c.Host, c.Port), Path: "/" + c.Name, } q := u.Query() q.Set("sslmode", c.SSLMode) q.Set("timezone", c.Timezone) u.RawQuery = q.Encode() return u.String() } // ctxTxKey is the context key for the active transaction. type ctxTxKey struct{} // --- pgComponent --- type pgComponent struct { logger logz.Logger cfg Config pool *pgxpool.Pool mu sync.RWMutex } // New returns a postgres Component. Call lc.Append(db) to manage its lifecycle. 
func New(logger logz.Logger, cfg Config) Component { return &pgComponent{logger: logger, cfg: cfg} } func (c *pgComponent) OnInit() error { poolCfg, err := c.buildPoolConfig() if err != nil { return fmt.Errorf("postgres: parse config: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() pool, err := pgxpool.NewWithConfig(ctx, poolCfg) if err != nil { c.logger.Error("postgres: failed to create pool", err) return fmt.Errorf("postgres: create pool: %w", err) } c.mu.Lock() c.pool = pool c.mu.Unlock() return nil } func (c *pgComponent) OnStart() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := c.Ping(ctx); err != nil { return fmt.Errorf("postgres: ping failed: %w", err) } c.logger.Info("postgres: connected") return nil } func (c *pgComponent) OnStop() error { c.mu.Lock() defer c.mu.Unlock() if c.pool != nil { c.logger.Info("postgres: closing pool") c.pool.Close() c.pool = nil } return nil } func (c *pgComponent) Ping(ctx context.Context) error { c.mu.RLock() pool := c.pool c.mu.RUnlock() if pool == nil { return fmt.Errorf("postgres: pool not initialized") } return pool.Ping(ctx) } func (c *pgComponent) GetExecutor(ctx context.Context) Executor { if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok { return tx } c.mu.RLock() pool := c.pool c.mu.RUnlock() return pool } func (c *pgComponent) Begin(ctx context.Context) (Tx, error) { c.mu.RLock() pool := c.pool c.mu.RUnlock() if pool == nil { return nil, fmt.Errorf("postgres: pool not initialized") } tx, err := pool.Begin(ctx) if err != nil { return nil, err } return &pgTx{Tx: tx}, nil } func (c *pgComponent) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { c.mu.RLock() pool := c.pool c.mu.RUnlock() return pool.Exec(ctx, sql, args...) 
} func (c *pgComponent) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { c.mu.RLock() pool := c.pool c.mu.RUnlock() return pool.Query(ctx, sql, args...) } func (c *pgComponent) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { c.mu.RLock() pool := c.pool c.mu.RUnlock() return pool.QueryRow(ctx, sql, args...) } func (c *pgComponent) HandleError(err error) error { return HandleError(err) } // health.Checkable func (c *pgComponent) HealthCheck(ctx context.Context) error { return c.Ping(ctx) } func (c *pgComponent) Name() string { return "postgres" } func (c *pgComponent) Priority() health.Level { return health.LevelCritical } func (c *pgComponent) buildPoolConfig() (*pgxpool.Config, error) { cfg, err := pgxpool.ParseConfig(c.cfg.DSN()) if err != nil { return nil, err } cfg.MaxConns = int32(c.cfg.MaxConns) cfg.MinConns = int32(c.cfg.MinConns) if c.cfg.MaxConnLifetime != "" { d, err := time.ParseDuration(c.cfg.MaxConnLifetime) if err != nil { return nil, fmt.Errorf("PG_MAX_CONN_LIFETIME: %w", err) } cfg.MaxConnLifetime = d } if c.cfg.MaxConnIdleTime != "" { d, err := time.ParseDuration(c.cfg.MaxConnIdleTime) if err != nil { return nil, fmt.Errorf("PG_MAX_CONN_IDLE_TIME: %w", err) } cfg.MaxConnIdleTime = d } if c.cfg.HealthCheckPeriod != "" { d, err := time.ParseDuration(c.cfg.HealthCheckPeriod) if err != nil { return nil, fmt.Errorf("PG_HEALTH_CHECK_PERIOD: %w", err) } cfg.HealthCheckPeriod = d } return cfg, nil } // --- pgTx --- type pgTx struct{ pgx.Tx } func (t *pgTx) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { return t.Tx.Exec(ctx, sql, args...) } func (t *pgTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { return t.Tx.Query(ctx, sql, args...) } func (t *pgTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return t.Tx.QueryRow(ctx, sql, args...) 
} func (t *pgTx) Commit(ctx context.Context) error { return t.Tx.Commit(ctx) } func (t *pgTx) Rollback(ctx context.Context) error { return t.Tx.Rollback(ctx) } // --- UnitOfWork --- type unitOfWork struct { logger logz.Logger client Client } // NewUnitOfWork returns a UnitOfWork backed by the given client. func NewUnitOfWork(logger logz.Logger, client Client) UnitOfWork { return &unitOfWork{logger: logger, client: client} } func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error { tx, err := u.client.Begin(ctx) if err != nil { return fmt.Errorf("postgres: begin transaction: %w", err) } ctx = context.WithValue(ctx, ctxTxKey{}, tx) if err := fn(ctx); err != nil { if rbErr := tx.Rollback(ctx); rbErr != nil { u.logger.Error("postgres: rollback failed", rbErr) } return err } return tx.Commit(ctx) }