package sqlite

import (
	"context"
	"database/sql"
	"sync"
	"time"

	_ "modernc.org/sqlite" // register sqlite driver

	"code.nochebuena.dev/einherjar/contracts/logging"
	"code.nochebuena.dev/einherjar/contracts/observability"
	"code.nochebuena.dev/einherjar/core/xerrors"
)

// Compile-time interface verification (I-8 / CT-5).
var _ Component = (*sqliteImpl)(nil)
var _ observability.Identifiable = (*sqliteImpl)(nil)
var _ Tx = (*sqliteTx)(nil)
var _ UnitOfWork = (*unitOfWork)(nil)

// New returns a Component backed by the given configuration.
// The database is not opened until OnInit is called.
func New(logger logging.Logger, cfg Config) Component {
	return &sqliteImpl{logger: logger, cfg: cfg}
}

// NewUnitOfWork returns a UnitOfWork backed by the given client.
// When client is the result of [New], write transactions are serialized
// through an internal mutex to prevent SQLITE_BUSY errors. Any other
// Provider implementation gets no serialization.
func NewUnitOfWork(logger logging.Logger, client Provider) UnitOfWork {
	var mu *sync.Mutex
	if sc, ok := client.(*sqliteImpl); ok {
		mu = &sc.writeMu
	}
	return &unitOfWork{logger: logger, client: client, writeMu: mu}
}

// ctxTxKey is the context key for the active transaction.
type ctxTxKey struct{}

// --- sqliteImpl ---

// sqliteImpl is the Component returned by New. mu guards db, which is nil
// before OnInit and after OnStop. writeMu is shared with unit-of-work
// instances created through NewUnitOfWork.
type sqliteImpl struct {
	logger  logging.Logger
	cfg     Config
	db      *sql.DB
	mu      sync.RWMutex
	writeMu sync.Mutex // serializes writes to prevent SQLITE_BUSY
}

// currentDB returns the open *sql.DB, or nil if the component has not been
// initialized or has already been stopped. The read lock covers only the
// pointer load; callers use the returned handle without holding it.
func (c *sqliteImpl) currentDB() *sql.DB {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.db
}

// OnInit opens the database described by cfg and applies the connection
// pool settings and the foreign-key PRAGMA. Connectivity is verified later,
// by OnStart.
func (c *sqliteImpl) OnInit() error {
	db, err := sql.Open("sqlite", c.cfg.DSN())
	if err != nil {
		return xerrors.New(xerrors.ErrInternal, "sqlite: open").WithError(err)
	}

	// A zero MaxOpenConns means "unset" and defaults to one connection.
	maxOpen := c.cfg.MaxOpenConns
	if maxOpen == 0 {
		maxOpen = 1
	}
	db.SetMaxOpenConns(maxOpen)

	// A zero MaxIdleConns is likewise treated as "unset". Passing 0 to
	// SetMaxIdleConns would disable idle pooling entirely, so every statement
	// would open and discard a fresh SQLite connection, losing per-connection
	// PRAGMAs (and, for a plain :memory: database, the data itself).
	switch {
	case c.cfg.MaxIdleConns != 0:
		db.SetMaxIdleConns(c.cfg.MaxIdleConns)
	case maxOpen > 0:
		db.SetMaxIdleConns(maxOpen)
	}
	// Otherwise (unlimited open connections, idle unset) database/sql's
	// default idle pool size is kept.

	// Enable foreign-key enforcement (SQLite disables it by default). The
	// PRAGMA is per-connection and this Exec reaches a single pooled
	// connection.
	// NOTE(review): with MaxOpenConns > 1 the other connections do not get
	// it unless Config.DSN() also sets it (e.g. `_pragma=foreign_keys(1)`)
	// — confirm.
	if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
		_ = db.Close() // best-effort cleanup; the PRAGMA error is what matters
		return xerrors.New(xerrors.ErrInternal, "sqlite: enable foreign keys").WithError(err)
	}

	c.mu.Lock()
	c.db = db
	c.mu.Unlock()
	return nil
}

// OnStart pings the database with a 5-second timeout and logs the path on
// success. A failed ping is reported as ErrUnavailable.
func (c *sqliteImpl) OnStart() error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := c.Ping(ctx); err != nil {
		return xerrors.New(xerrors.ErrUnavailable, "sqlite: ping failed").WithError(err)
	}
	c.logger.Info("sqlite: connected", "path", c.cfg.Path)
	return nil
}

// OnStop closes the database, if open, and clears the handle. It always
// returns nil; a Close failure is logged rather than aborting shutdown.
func (c *sqliteImpl) OnStop() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.db == nil {
		return nil
	}
	c.logger.Info("sqlite: closing")
	if err := c.db.Close(); err != nil {
		c.logger.Error("sqlite: close failed", err)
	}
	c.db = nil
	return nil
}

// Ping verifies the database is reachable. It returns ErrInternal if the
// component is not initialized.
func (c *sqliteImpl) Ping(ctx context.Context) error {
	db := c.currentDB()
	if db == nil {
		return xerrors.New(xerrors.ErrInternal, "sqlite: not initialized")
	}
	return db.PingContext(ctx)
}

// GetExecutor returns the transaction stored in ctx by unitOfWork.Do, if
// any; otherwise the shared *sql.DB. It returns nil when no transaction is
// active and the component is not initialized.
func (c *sqliteImpl) GetExecutor(ctx context.Context) Executor {
	if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok {
		return tx
	}
	// Return a literal nil rather than a typed-nil *sql.DB so callers can
	// compare the interface against nil.
	if db := c.currentDB(); db != nil {
		return db
	}
	return nil
}

// Begin starts a new transaction with default options.
func (c *sqliteImpl) Begin(ctx context.Context) (Tx, error) {
	db := c.currentDB()
	if db == nil {
		return nil, xerrors.New(xerrors.ErrInternal, "sqlite: not initialized")
	}
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, xerrors.New(xerrors.ErrInternal, "sqlite: begin transaction").WithError(err)
	}
	return &sqliteTx{Tx: tx}, nil
}

// ExecContext runs query on the shared *sql.DB (not on any transaction held
// in ctx). It returns ErrInternal if the component is not initialized.
func (c *sqliteImpl) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	db := c.currentDB()
	if db == nil {
		return nil, xerrors.New(xerrors.ErrInternal, "sqlite: not initialized")
	}
	return db.ExecContext(ctx, query, args...)
}

// QueryContext runs query on the shared *sql.DB (not on any transaction
// held in ctx). It returns ErrInternal if the component is not initialized.
func (c *sqliteImpl) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	db := c.currentDB()
	if db == nil {
		return nil, xerrors.New(xerrors.ErrInternal, "sqlite: not initialized")
	}
	return db.QueryContext(ctx, query, args...)
}

// QueryRowContext runs query on the shared *sql.DB.
//
// It must only be called between OnInit and OnStop: *sql.Row cannot carry a
// custom error, so with no open database the nil handle is dereferenced and
// the call panics.
func (c *sqliteImpl) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	return c.currentDB().QueryRowContext(ctx, query, args...)
}

// HandleError delegates to the package-level HandleError.
func (c *sqliteImpl) HandleError(err error) error { return HandleError(err) }

// HealthCheck reports database reachability via Ping.
func (c *sqliteImpl) HealthCheck(ctx context.Context) error { return c.Ping(ctx) }

// Name identifies this component for observability.
func (c *sqliteImpl) Name() string { return "sqlite" }

// Priority marks this component as critical for health reporting.
func (c *sqliteImpl) Priority() observability.Level { return observability.LevelCritical }

// --- sqliteTx ---

// sqliteTx adapts *sql.Tx to the Tx interface. ExecContext, QueryContext,
// QueryRowContext, Commit and Rollback are promoted unchanged from the
// embedded *sql.Tx, so no forwarding methods are needed.
type sqliteTx struct{ *sql.Tx }

// --- unitOfWork ---

// unitOfWork runs functions inside a transaction obtained from client.
type unitOfWork struct {
	logger  logging.Logger
	client  Provider
	writeMu *sync.Mutex // nil unless NewUnitOfWork received a *sqliteImpl
}

// Do runs fn inside a single transaction, which it exposes to fn through ctx
// (GetExecutor picks it up). If fn returns an error, the transaction is
// rolled back and fn's error is returned. Otherwise Do returns the result
// of Commit. If fn panics, the transaction is rolled back before the panic
// continues to propagate, so its connection is returned to the pool.
//
// Do is not reentrant: when writeMu is set, calling Do again from inside fn
// blocks forever on the same mutex.
func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error {
	if u.writeMu != nil {
		u.writeMu.Lock()
		defer u.writeMu.Unlock()
	}

	tx, err := u.client.Begin(ctx)
	if err != nil {
		return xerrors.New(xerrors.ErrInternal, "sqlite: begin transaction").WithError(err)
	}

	// This defer is registered after the unlock above, so it runs first
	// (LIFO): the transaction is rolled back before the write mutex is
	// released.
	defer func() {
		if p := recover(); p != nil {
			u.rollback(tx)
			panic(p)
		}
	}()

	if err := fn(context.WithValue(ctx, ctxTxKey{}, tx)); err != nil {
		u.rollback(tx)
		return err
	}
	return tx.Commit()
}

// rollback aborts tx and logs any failure instead of returning it, so the
// caller's original error or panic is what propagates.
func (u *unitOfWork) rollback(tx Tx) {
	if rbErr := tx.Rollback(); rbErr != nil {
		u.logger.Error("sqlite: rollback failed", rbErr)
	}
}