package mysql import ( "context" "database/sql" "sync" "time" _ "github.com/go-sql-driver/mysql" // register driver "code.nochebuena.dev/einherjar/contracts/logging" "code.nochebuena.dev/einherjar/contracts/observability" "code.nochebuena.dev/einherjar/core/xerrors" ) // Compile-time interface verification (I-8 / CT-5). var _ Component = (*mysqlImpl)(nil) var _ observability.Identifiable = (*mysqlImpl)(nil) var _ Tx = (*mysqlTx)(nil) var _ UnitOfWork = (*unitOfWork)(nil) // New returns a Component backed by the given configuration. // The connection pool is not created until OnInit is called. func New(logger logging.Logger, cfg Config) Component { return &mysqlImpl{logger: logger, cfg: cfg} } // NewUnitOfWork returns a UnitOfWork backed by the given client. func NewUnitOfWork(logger logging.Logger, client Provider) UnitOfWork { return &unitOfWork{logger: logger, client: client} } // ctxTxKey is the context key for the active transaction. type ctxTxKey struct{} // --- mysqlImpl --- type mysqlImpl struct { logger logging.Logger cfg Config db *sql.DB mu sync.RWMutex } func (c *mysqlImpl) OnInit() error { db, err := sql.Open("mysql", c.cfg.DSN()) if err != nil { return xerrors.New(xerrors.ErrInvalidInput, "mysql: open").WithError(err) } db.SetMaxOpenConns(c.cfg.MaxConns) db.SetMaxIdleConns(c.cfg.MinConns) if c.cfg.MaxConnLifetime != "" { d, err := time.ParseDuration(c.cfg.MaxConnLifetime) if err != nil { return xerrors.New(xerrors.ErrInvalidInput, "mysql: parse EINHERJAR_MYSQL_MAX_CONN_LIFETIME").WithError(err) } db.SetConnMaxLifetime(d) } if c.cfg.MaxConnIdleTime != "" { d, err := time.ParseDuration(c.cfg.MaxConnIdleTime) if err != nil { return xerrors.New(xerrors.ErrInvalidInput, "mysql: parse EINHERJAR_MYSQL_MAX_CONN_IDLE_TIME").WithError(err) } db.SetConnMaxIdleTime(d) } c.mu.Lock() c.db = db c.mu.Unlock() return nil } func (c *mysqlImpl) OnStart() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := c.Ping(ctx); err != nil { return xerrors.New(xerrors.ErrUnavailable, "mysql: ping failed").WithError(err) } c.logger.Info("mysql: connected") return nil } func (c *mysqlImpl) OnStop() error { c.mu.Lock() defer c.mu.Unlock() if c.db != nil { c.logger.Info("mysql: closing pool") _ = c.db.Close() c.db = nil } return nil } func (c *mysqlImpl) Name() string { return "mysql" } func (c *mysqlImpl) Priority() observability.Level { return observability.LevelCritical } func (c *mysqlImpl) HealthCheck(ctx context.Context) error { return c.Ping(ctx) } func (c *mysqlImpl) Ping(ctx context.Context) error { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return xerrors.New(xerrors.ErrInternal, "mysql: not initialized") } return db.PingContext(ctx) } func (c *mysqlImpl) GetExecutor(ctx context.Context) Executor { if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok { return tx } c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return nil } return db } func (c *mysqlImpl) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return nil, xerrors.New(xerrors.ErrInternal, "mysql: not initialized") } tx, err := db.BeginTx(ctx, opts) if err != nil { return nil, xerrors.New(xerrors.ErrInternal, "mysql: begin transaction").WithError(err) } return &mysqlTx{Tx: tx}, nil } func (c *mysqlImpl) Begin(ctx context.Context) (Tx, error) { return c.BeginTx(ctx, nil) } func (c *mysqlImpl) Stats() sql.DBStats { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return sql.DBStats{} } return db.Stats() } func (c *mysqlImpl) HandleError(err error) error { return HandleError(err) } // --- mysqlTx --- type mysqlTx struct{ *sql.Tx } func (t *mysqlTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { return t.Tx.ExecContext(ctx, query, args...) } func (t *mysqlTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return t.Tx.QueryContext(ctx, query, args...) } func (t *mysqlTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { return t.Tx.QueryRowContext(ctx, query, args...) } func (t *mysqlTx) Commit() error { return t.Tx.Commit() } func (t *mysqlTx) Rollback() error { return t.Tx.Rollback() } // --- unitOfWork --- type unitOfWork struct { logger logging.Logger client Provider } func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error { tx, err := u.client.Begin(ctx) if err != nil { return xerrors.New(xerrors.ErrInternal, "mysql: begin transaction").WithError(err) } ctx = context.WithValue(ctx, ctxTxKey{}, tx) if err := fn(ctx); err != nil { if rbErr := tx.Rollback(); rbErr != nil { u.logger.Error("mysql: rollback failed", rbErr) } return err } return tx.Commit() }