package mysql

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"sync"
	"time"

	_ "github.com/go-sql-driver/mysql" // register driver

	"code.nochebuena.dev/go/health"
	"code.nochebuena.dev/go/launcher"
	"code.nochebuena.dev/go/logz"
)

// Executor is the query surface shared by the connection pool and an open
// transaction. *sql.DB satisfies it directly, and so does the Tx returned by
// Client.Begin, so the same statement code can run with or without a
// transaction.
type Executor interface {
	// ExecContext runs a statement that returns no rows (INSERT, UPDATE, ...).
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// QueryContext runs a statement returning rows; the caller must close them.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	// QueryRowContext runs a statement expected to return at most one row;
	// errors are deferred until Scan on the returned *sql.Row.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// Tx extends Executor with commit/rollback (no ctx — sql.Tx limitation).
// *sql.Tx's Commit and Rollback take no context; the context passed to
// Client.Begin is the one that governs the transaction's lifetime.
type Tx interface {
	Executor
	Commit() error
	Rollback() error
}

// Client is the database access interface.
//
// GetExecutor returns the transaction stored in ctx by UnitOfWork.Do when
// there is one, and the connection pool otherwise. Begin starts a new
// transaction, Ping checks connectivity, and HandleError translates a
// database error into this package's error form (delegates to the
// package-level HandleError — see that function for the mapping).
type Client interface {
	GetExecutor(ctx context.Context) Executor
	Begin(ctx context.Context) (Tx, error)
	Ping(ctx context.Context) error
	HandleError(err error) error
}

// Component bundles launcher lifecycle, health check, and database client.
// It is what New returns: register it with a launcher so OnInit/OnStart/
// OnStop manage the pool, and with the health checker via health.Checkable.
type Component interface {
	launcher.Component
	health.Checkable
	Client
}

// UnitOfWork wraps operations in a single database transaction.
//
// Do begins a transaction, passes fn a context carrying it (so
// Client.GetExecutor(ctx) inside fn returns the transaction), rolls back if
// fn returns an error, and commits otherwise.
type UnitOfWork interface {
	Do(ctx context.Context, fn func(ctx context.Context) error) error
}

// Config holds MySQL connection settings.
//
// DSN parameters Charset, Loc, and ParseTime default to "utf8mb4", "UTC", and
// "true" respectively when left empty, preserving the behaviour of v0.9.0.
// Set them explicitly when you need non-default values (e.g. Loc="Local").
//
// Note on Collation: go-sql-driver v1.8.x negotiates the connection collation
// via a 1-byte handshake ID (max 255). MariaDB 11.4+ collations such as
// utf8mb4_uca1400_as_cs carry IDs > 255 and cannot be set through the DSN
// collation parameter. Set the desired collation at the database/table level
// in your schema migrations instead.
type Config struct { Host string `env:"MYSQL_HOST,required"` Port int `env:"MYSQL_PORT" envDefault:"3306"` User string `env:"MYSQL_USER,required"` Password string `env:"MYSQL_PASSWORD,required"` Name string `env:"MYSQL_NAME,required"` MaxConns int `env:"MYSQL_MAX_CONNS" envDefault:"5"` MinConns int `env:"MYSQL_MIN_CONNS" envDefault:"2"` MaxConnLifetime string `env:"MYSQL_MAX_CONN_LIFETIME" envDefault:"1h"` MaxConnIdleTime string `env:"MYSQL_MAX_CONN_IDLE_TIME" envDefault:"30m"` // Charset is the connection character set sent as SET NAMES . // Defaults to "utf8mb4" when empty. Charset string `env:"MYSQL_CHARSET" envDefault:"utf8mb4"` // Loc is the IANA timezone name used for time.Time ↔ MySQL DATETIME // conversion. Defaults to "UTC" when empty. Loc string `env:"MYSQL_LOC" envDefault:"UTC"` // ParseTime controls whether the driver maps DATE/DATETIME columns to // time.Time. Valid values: "true", "false". Defaults to "true" when empty. ParseTime string `env:"MYSQL_PARSE_TIME" envDefault:"true"` } // DSN constructs a MySQL DSN from the configuration. // Empty Charset, Loc, and ParseTime fields fall back to their safe defaults // ("utf8mb4", "UTC", "true"), matching the behaviour of v0.9.0. func (c Config) DSN() string { charset := c.Charset if charset == "" { charset = "utf8mb4" } loc := c.Loc if loc == "" { loc = "UTC" } parseTime := c.ParseTime if parseTime == "" { parseTime = "true" } q := url.Values{} q.Set("charset", charset) q.Set("loc", loc) q.Set("parseTime", parseTime) // go-sql-driver uses user:pass@tcp(host:port)/db?params return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", c.User, c.Password, c.Host, c.Port, c.Name, q.Encode()) } // ctxTxKey is the context key for the active transaction. type ctxTxKey struct{} // --- mysqlComponent --- type mysqlComponent struct { logger logz.Logger cfg Config db *sql.DB mu sync.RWMutex } // New returns a mysql Component. Call lc.Append(db) to manage its lifecycle. 
func New(logger logz.Logger, cfg Config) Component { return &mysqlComponent{logger: logger, cfg: cfg} } func (c *mysqlComponent) OnInit() error { db, err := sql.Open("mysql", c.cfg.DSN()) if err != nil { return fmt.Errorf("mysql: open: %w", err) } db.SetMaxOpenConns(c.cfg.MaxConns) db.SetMaxIdleConns(c.cfg.MinConns) if c.cfg.MaxConnLifetime != "" { d, err := time.ParseDuration(c.cfg.MaxConnLifetime) if err != nil { return fmt.Errorf("MYSQL_MAX_CONN_LIFETIME: %w", err) } db.SetConnMaxLifetime(d) } if c.cfg.MaxConnIdleTime != "" { d, err := time.ParseDuration(c.cfg.MaxConnIdleTime) if err != nil { return fmt.Errorf("MYSQL_MAX_CONN_IDLE_TIME: %w", err) } db.SetConnMaxIdleTime(d) } c.mu.Lock() c.db = db c.mu.Unlock() return nil } func (c *mysqlComponent) OnStart() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := c.Ping(ctx); err != nil { return fmt.Errorf("mysql: ping failed: %w", err) } c.logger.Info("mysql: connected") return nil } func (c *mysqlComponent) OnStop() error { c.mu.Lock() defer c.mu.Unlock() if c.db != nil { c.logger.Info("mysql: closing pool") _ = c.db.Close() c.db = nil } return nil } func (c *mysqlComponent) Ping(ctx context.Context) error { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return fmt.Errorf("mysql: not initialized") } return db.PingContext(ctx) } func (c *mysqlComponent) GetExecutor(ctx context.Context) Executor { if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok { return tx } c.mu.RLock() db := c.db c.mu.RUnlock() return db } func (c *mysqlComponent) Begin(ctx context.Context) (Tx, error) { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return nil, fmt.Errorf("mysql: not initialized") } tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err } return &mysqlTx{Tx: tx}, nil } func (c *mysqlComponent) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { c.mu.RLock() db := c.db c.mu.RUnlock() return db.ExecContext(ctx, query, args...) 
} func (c *mysqlComponent) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { c.mu.RLock() db := c.db c.mu.RUnlock() return db.QueryContext(ctx, query, args...) } func (c *mysqlComponent) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { c.mu.RLock() db := c.db c.mu.RUnlock() return db.QueryRowContext(ctx, query, args...) } func (c *mysqlComponent) HandleError(err error) error { return HandleError(err) } // health.Checkable func (c *mysqlComponent) HealthCheck(ctx context.Context) error { return c.Ping(ctx) } func (c *mysqlComponent) Name() string { return "mysql" } func (c *mysqlComponent) Priority() health.Level { return health.LevelCritical } // --- mysqlTx --- type mysqlTx struct{ *sql.Tx } func (t *mysqlTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { return t.Tx.ExecContext(ctx, query, args...) } func (t *mysqlTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return t.Tx.QueryContext(ctx, query, args...) } func (t *mysqlTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { return t.Tx.QueryRowContext(ctx, query, args...) } func (t *mysqlTx) Commit() error { return t.Tx.Commit() } func (t *mysqlTx) Rollback() error { return t.Tx.Rollback() } // --- UnitOfWork --- type unitOfWork struct { logger logz.Logger client Client } // NewUnitOfWork returns a UnitOfWork backed by the given client. func NewUnitOfWork(logger logz.Logger, client Client) UnitOfWork { return &unitOfWork{logger: logger, client: client} } func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error { tx, err := u.client.Begin(ctx) if err != nil { return fmt.Errorf("mysql: begin transaction: %w", err) } ctx = context.WithValue(ctx, ctxTxKey{}, tx) if err := fn(ctx); err != nil { if rbErr := tx.Rollback(); rbErr != nil { u.logger.Error("mysql: rollback failed", rbErr) } return err } return tx.Commit() }