204 lines
5.3 KiB
Go
204 lines
5.3 KiB
Go
|
|
package sqlite
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"database/sql"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
_ "modernc.org/sqlite" // register sqlite driver
|
||
|
|
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/logging"
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/observability"
|
||
|
|
"code.nochebuena.dev/einherjar/core/xerrors"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Compile-time interface verification (I-8 / CT-5).
|
||
|
|
var _ Component = (*sqliteImpl)(nil)
|
||
|
|
var _ observability.Identifiable = (*sqliteImpl)(nil)
|
||
|
|
var _ Tx = (*sqliteTx)(nil)
|
||
|
|
var _ UnitOfWork = (*unitOfWork)(nil)
|
||
|
|
|
||
|
|
// New returns a Component backed by the given configuration.
|
||
|
|
// The database is not opened until OnInit is called.
|
||
|
|
func New(logger logging.Logger, cfg Config) Component {
|
||
|
|
return &sqliteImpl{logger: logger, cfg: cfg}
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewUnitOfWork returns a UnitOfWork backed by the given client.
|
||
|
|
// When client is the result of [New], write transactions are serialized
|
||
|
|
// through an internal mutex to prevent SQLITE_BUSY errors.
|
||
|
|
func NewUnitOfWork(logger logging.Logger, client Provider) UnitOfWork {
|
||
|
|
var mu *sync.Mutex
|
||
|
|
if sc, ok := client.(*sqliteImpl); ok {
|
||
|
|
mu = &sc.writeMu
|
||
|
|
}
|
||
|
|
return &unitOfWork{logger: logger, client: client, writeMu: mu}
|
||
|
|
}
|
||
|
|
|
||
|
|
// ctxTxKey is the context key under which unitOfWork.Do stores the active
// transaction. sqliteImpl.GetExecutor looks it up so queries issued inside Do
// run in that transaction. Being an unexported empty struct type, the key
// cannot be constructed (and so cannot collide) outside this package.
type ctxTxKey struct{}
|
||
|
|
|
||
|
|
// --- sqliteImpl ---

// sqliteImpl is the Component returned by New. OnInit opens its *sql.DB and
// OnStop closes it.
type sqliteImpl struct {
	logger logging.Logger
	cfg    Config

	// db is nil until OnInit succeeds and again after OnStop. It is accessed
	// under mu: methods copy it out under RLock and then use the copy unlocked.
	db *sql.DB
	mu sync.RWMutex

	// writeMu is shared (by address) with every UnitOfWork that NewUnitOfWork
	// builds over this component. unitOfWork.Do holds it for the whole
	// transaction so those transactions never contend for SQLite's write lock
	// (SQLITE_BUSY). Direct ExecContext calls on the component do not take it.
	writeMu sync.Mutex
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) OnInit() error {
|
||
|
|
db, err := sql.Open("sqlite", c.cfg.DSN())
|
||
|
|
if err != nil {
|
||
|
|
return xerrors.New(xerrors.ErrInternal, "sqlite: open").WithError(err)
|
||
|
|
}
|
||
|
|
maxOpen := c.cfg.MaxOpenConns
|
||
|
|
if maxOpen == 0 {
|
||
|
|
maxOpen = 1
|
||
|
|
}
|
||
|
|
db.SetMaxOpenConns(maxOpen)
|
||
|
|
db.SetMaxIdleConns(c.cfg.MaxIdleConns)
|
||
|
|
// Enforce foreign keys per-connection (SQLite disables them by default).
|
||
|
|
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||
|
|
_ = db.Close()
|
||
|
|
return xerrors.New(xerrors.ErrInternal, "sqlite: enable foreign keys").WithError(err)
|
||
|
|
}
|
||
|
|
c.mu.Lock()
|
||
|
|
c.db = db
|
||
|
|
c.mu.Unlock()
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) OnStart() error {
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
|
|
defer cancel()
|
||
|
|
if err := c.Ping(ctx); err != nil {
|
||
|
|
return xerrors.New(xerrors.ErrUnavailable, "sqlite: ping failed").WithError(err)
|
||
|
|
}
|
||
|
|
c.logger.Info("sqlite: connected", "path", c.cfg.Path)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) OnStop() error {
|
||
|
|
c.mu.Lock()
|
||
|
|
defer c.mu.Unlock()
|
||
|
|
if c.db != nil {
|
||
|
|
c.logger.Info("sqlite: closing")
|
||
|
|
_ = c.db.Close()
|
||
|
|
c.db = nil
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) Ping(ctx context.Context) error {
|
||
|
|
c.mu.RLock()
|
||
|
|
db := c.db
|
||
|
|
c.mu.RUnlock()
|
||
|
|
if db == nil {
|
||
|
|
return xerrors.New(xerrors.ErrInternal, "sqlite: not initialized")
|
||
|
|
}
|
||
|
|
return db.PingContext(ctx)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) GetExecutor(ctx context.Context) Executor {
|
||
|
|
if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok {
|
||
|
|
return tx
|
||
|
|
}
|
||
|
|
c.mu.RLock()
|
||
|
|
db := c.db
|
||
|
|
c.mu.RUnlock()
|
||
|
|
if db == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return db
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) Begin(ctx context.Context) (Tx, error) {
|
||
|
|
c.mu.RLock()
|
||
|
|
db := c.db
|
||
|
|
c.mu.RUnlock()
|
||
|
|
if db == nil {
|
||
|
|
return nil, xerrors.New(xerrors.ErrInternal, "sqlite: not initialized")
|
||
|
|
}
|
||
|
|
tx, err := db.BeginTx(ctx, nil)
|
||
|
|
if err != nil {
|
||
|
|
return nil, xerrors.New(xerrors.ErrInternal, "sqlite: begin transaction").WithError(err)
|
||
|
|
}
|
||
|
|
return &sqliteTx{Tx: tx}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||
|
|
c.mu.RLock()
|
||
|
|
db := c.db
|
||
|
|
c.mu.RUnlock()
|
||
|
|
return db.ExecContext(ctx, query, args...)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||
|
|
c.mu.RLock()
|
||
|
|
db := c.db
|
||
|
|
c.mu.RUnlock()
|
||
|
|
return db.QueryContext(ctx, query, args...)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
||
|
|
c.mu.RLock()
|
||
|
|
db := c.db
|
||
|
|
c.mu.RUnlock()
|
||
|
|
return db.QueryRowContext(ctx, query, args...)
|
||
|
|
}
|
||
|
|
|
||
|
|
// HandleError exposes the package-level HandleError function through the
// component; it returns exactly what that function returns for err.
func (c *sqliteImpl) HandleError(err error) error {
	translated := HandleError(err)
	return translated
}
|
||
|
|
|
||
|
|
func (c *sqliteImpl) HealthCheck(ctx context.Context) error { return c.Ping(ctx) }
|
||
|
|
func (c *sqliteImpl) Name() string { return "sqlite" }
|
||
|
|
func (c *sqliteImpl) Priority() observability.Level { return observability.LevelCritical }
|
||
|
|
|
||
|
|
// --- sqliteTx ---

// sqliteTx adapts *sql.Tx to the package's Tx interface. The *sql.Tx is
// embedded, so its remaining methods are promoted as well.
type sqliteTx struct{ *sql.Tx }
|
||
|
|
|
||
|
|
func (t *sqliteTx) ExecContext(ctx context.Context, q string, args ...any) (sql.Result, error) {
|
||
|
|
return t.Tx.ExecContext(ctx, q, args...)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *sqliteTx) QueryContext(ctx context.Context, q string, args ...any) (*sql.Rows, error) {
|
||
|
|
return t.Tx.QueryContext(ctx, q, args...)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *sqliteTx) QueryRowContext(ctx context.Context, q string, args ...any) *sql.Row {
|
||
|
|
return t.Tx.QueryRowContext(ctx, q, args...)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *sqliteTx) Commit() error { return t.Tx.Commit() }
|
||
|
|
func (t *sqliteTx) Rollback() error { return t.Tx.Rollback() }
|
||
|
|
|
||
|
|
// --- unitOfWork ---

// unitOfWork implements UnitOfWork on top of a Provider; see Do.
type unitOfWork struct {
	logger logging.Logger
	// client begins the transaction each Do call runs its callback in.
	client Provider
	// writeMu, when non-nil, is the owning sqliteImpl's write mutex (set by
	// NewUnitOfWork). It is held for the whole of each Do call.
	writeMu *sync.Mutex
}
|
||
|
|
|
||
|
|
func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error {
|
||
|
|
if u.writeMu != nil {
|
||
|
|
u.writeMu.Lock()
|
||
|
|
defer u.writeMu.Unlock()
|
||
|
|
}
|
||
|
|
tx, err := u.client.Begin(ctx)
|
||
|
|
if err != nil {
|
||
|
|
return xerrors.New(xerrors.ErrInternal, "sqlite: begin transaction").WithError(err)
|
||
|
|
}
|
||
|
|
ctx = context.WithValue(ctx, ctxTxKey{}, tx)
|
||
|
|
if err := fn(ctx); err != nil {
|
||
|
|
if rbErr := tx.Rollback(); rbErr != nil {
|
||
|
|
u.logger.Error("sqlite: rollback failed", rbErr)
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return tx.Commit()
|
||
|
|
}
|