package sqlite import ( "context" "database/sql" "fmt" "sync" "time" _ "modernc.org/sqlite" // register sqlite driver "code.nochebuena.dev/go/health" "code.nochebuena.dev/go/launcher" "code.nochebuena.dev/go/logz" ) // Executor defines operations shared by the connection and transaction. type Executor interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row } // Tx extends Executor with commit/rollback. // Honest contract: database/sql Tx does not accept ctx on Commit/Rollback. type Tx interface { Executor Commit() error Rollback() error } // Client is the primary interface for consumers. type Client interface { GetExecutor(ctx context.Context) Executor Begin(ctx context.Context) (Tx, error) Ping(ctx context.Context) error HandleError(err error) error } // Component bundles lifecycle + health + client. type Component interface { launcher.Component health.Checkable Client } // UnitOfWork manages the transaction lifecycle via context injection. type UnitOfWork interface { Do(ctx context.Context, fn func(ctx context.Context) error) error } // Config holds connection parameters. type Config struct { // Path is the SQLite file path. Use ":memory:" for in-memory databases. Path string `env:"SQLITE_PATH,required"` // MaxOpenConns limits concurrent connections. Default: 1. MaxOpenConns int `env:"SQLITE_MAX_OPEN_CONNS" envDefault:"1"` // MaxIdleConns is the number of idle connections kept in the pool. MaxIdleConns int `env:"SQLITE_MAX_IDLE_CONNS" envDefault:"1"` // Pragmas are appended to the DSN. Default: WAL + 5s busy timeout + FK enforcement. Pragmas string `env:"SQLITE_PRAGMAS" envDefault:"?_journal=WAL&_timeout=5000&_fk=true"` } func (c Config) dsn() string { return c.Path + c.Pragmas } // ctxTxKey is the context key for the active transaction. type ctxTxKey struct{} // --- sqliteComponent --- type sqliteComponent struct { logger logz.Logger cfg Config db *sql.DB mu sync.RWMutex writeMu sync.Mutex // serialises writes to prevent SQLITE_BUSY } // New returns a sqlite Component. Call lc.Append(db) to manage its lifecycle. func New(logger logz.Logger, cfg Config) Component { return &sqliteComponent{logger: logger, cfg: cfg} } func (c *sqliteComponent) OnInit() error { db, err := sql.Open("sqlite", c.cfg.dsn()) if err != nil { return fmt.Errorf("sqlite: open: %w", err) } maxOpen := c.cfg.MaxOpenConns if maxOpen == 0 { maxOpen = 1 } db.SetMaxOpenConns(maxOpen) db.SetMaxIdleConns(c.cfg.MaxIdleConns) // Enforce foreign keys per-connection (SQLite disables them by default). if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { _ = db.Close() return fmt.Errorf("sqlite: enable foreign keys: %w", err) } c.mu.Lock() c.db = db c.mu.Unlock() return nil } func (c *sqliteComponent) OnStart() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := c.Ping(ctx); err != nil { return fmt.Errorf("sqlite: ping failed: %w", err) } c.logger.Info("sqlite: ready", "path", c.cfg.Path) return nil } func (c *sqliteComponent) OnStop() error { c.mu.Lock() defer c.mu.Unlock() if c.db != nil { c.logger.Info("sqlite: closing") _ = c.db.Close() c.db = nil } return nil } func (c *sqliteComponent) Ping(ctx context.Context) error { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return fmt.Errorf("sqlite: not initialized") } return db.PingContext(ctx) } func (c *sqliteComponent) GetExecutor(ctx context.Context) Executor { if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok { return tx } c.mu.RLock() db := c.db c.mu.RUnlock() return db } func (c *sqliteComponent) Begin(ctx context.Context) (Tx, error) { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return nil, fmt.Errorf("sqlite: not initialized") } tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err } return &sqliteTx{Tx: tx}, nil } func (c *sqliteComponent) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { c.mu.RLock() db := c.db c.mu.RUnlock() return db.ExecContext(ctx, query, args...) } func (c *sqliteComponent) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { c.mu.RLock() db := c.db c.mu.RUnlock() return db.QueryContext(ctx, query, args...) } func (c *sqliteComponent) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { c.mu.RLock() db := c.db c.mu.RUnlock() return db.QueryRowContext(ctx, query, args...) } func (c *sqliteComponent) HandleError(err error) error { return HandleError(err) } // health.Checkable func (c *sqliteComponent) HealthCheck(ctx context.Context) error { return c.Ping(ctx) } func (c *sqliteComponent) Name() string { return "sqlite" } func (c *sqliteComponent) Priority() health.Level { return health.LevelCritical } // --- sqliteTx --- type sqliteTx struct{ *sql.Tx } func (t *sqliteTx) ExecContext(ctx context.Context, q string, args ...any) (sql.Result, error) { return t.Tx.ExecContext(ctx, q, args...) } func (t *sqliteTx) QueryContext(ctx context.Context, q string, args ...any) (*sql.Rows, error) { return t.Tx.QueryContext(ctx, q, args...) } func (t *sqliteTx) QueryRowContext(ctx context.Context, q string, args ...any) *sql.Row { return t.Tx.QueryRowContext(ctx, q, args...) } func (t *sqliteTx) Commit() error { return t.Tx.Commit() } func (t *sqliteTx) Rollback() error { return t.Tx.Rollback() } // --- UnitOfWork --- type unitOfWork struct { logger logz.Logger client Client writeMu *sync.Mutex } // NewUnitOfWork returns a UnitOfWork backed by the given client. // If client is a *sqliteComponent, the write mutex is used to serialise transactions. func NewUnitOfWork(logger logz.Logger, client Client) UnitOfWork { var mu *sync.Mutex if sc, ok := client.(*sqliteComponent); ok { mu = &sc.writeMu } return &unitOfWork{logger: logger, client: client, writeMu: mu} } func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error { if u.writeMu != nil { u.writeMu.Lock() defer u.writeMu.Unlock() } tx, err := u.client.Begin(ctx) if err != nil { return fmt.Errorf("sqlite: begin transaction: %w", err) } ctx = context.WithValue(ctx, ctxTxKey{}, tx) if err := fn(ctx); err != nil { if rbErr := tx.Rollback(); rbErr != nil { u.logger.Error("sqlite: rollback failed", rbErr) } return err } return tx.Commit() }