Files
mysql/mysql.go

251 lines
6.7 KiB
Go
Raw Normal View History

package mysql
import (
"context"
"database/sql"
"fmt"
"net/url"
"sync"
"time"
_ "github.com/go-sql-driver/mysql" // register driver
"code.nochebuena.dev/go/health"
"code.nochebuena.dev/go/launcher"
"code.nochebuena.dev/go/logz"
)
// Executor is the query surface shared by the connection pool (*sql.DB) and
// by an open transaction, so data-access code can run the same statements
// whether or not it is inside a transaction.
type Executor interface {
	// QueryRowContext runs a query expected to return at most one row.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	// QueryContext runs a query and returns its result rows.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	// ExecContext runs a statement that does not return rows.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// Tx is an Executor bound to a single database transaction, finished with
// Commit or Rollback. Neither finishing method takes a context because the
// underlying (*sql.Tx).Commit and (*sql.Tx).Rollback do not accept one.
type Tx interface {
	Rollback() error
	Commit() error
	Executor
}
// Client is the database access interface.
type Client interface {
GetExecutor(ctx context.Context) Executor
Begin(ctx context.Context) (Tx, error)
Ping(ctx context.Context) error
HandleError(err error) error
}
// Component bundles launcher lifecycle, health check, and database client.
type Component interface {
launcher.Component
health.Checkable
Client
}
// UnitOfWork runs a callback inside one database transaction. The callback
// receives a derived context that carries that transaction.
type UnitOfWork interface {
	Do(ctx context.Context, fn func(ctx context.Context) error) error
}
// Config holds MySQL connection settings.
//
// Values are described by `env` / `envDefault` struct tags; presumably they
// are filled by an env-parsing library (NOTE(review): confirm which loader
// the callers use). Fields tagged `required` have no default.
type Config struct {
	// Host is the server host name or address; Port is its TCP port.
	// Both are placed inside tcp(...) in the DSN.
	Host string `env:"MYSQL_HOST,required"`
	Port int    `env:"MYSQL_PORT" envDefault:"3306"`
	// User and Password are the credentials, inserted verbatim into the DSN.
	User     string `env:"MYSQL_USER,required"`
	Password string `env:"MYSQL_PASSWORD,required"`
	// Name is the database (schema) name used as the DSN path.
	Name string `env:"MYSQL_NAME,required"`
	// MaxConns is passed to (*sql.DB).SetMaxOpenConns.
	MaxConns int `env:"MYSQL_MAX_CONNS" envDefault:"5"`
	// MinConns is passed to (*sql.DB).SetMaxIdleConns: database/sql has no
	// minimum pool size, so this caps idle connections rather than
	// guaranteeing a floor.
	MinConns int `env:"MYSQL_MIN_CONNS" envDefault:"2"`
	// MaxConnLifetime and MaxConnIdleTime are time.ParseDuration strings
	// (e.g. "1h", "30m") applied via SetConnMaxLifetime / SetConnMaxIdleTime.
	// An empty string leaves the corresponding setting untouched.
	MaxConnLifetime string `env:"MYSQL_MAX_CONN_LIFETIME" envDefault:"1h"`
	MaxConnIdleTime string `env:"MYSQL_MAX_CONN_IDLE_TIME" envDefault:"30m"`
}
// DSN builds a go-sql-driver/mysql data source name of the form
//
//	user:password@tcp(host:port)/dbname?loc=UTC&parseTime=true
//
// The query string is produced by url.Values.Encode, so parameters are
// sorted by key and percent-escaped. User and password are inserted as-is,
// following the driver's user[:password]@ syntax.
func (c Config) DSN() string {
	params := url.Values{}
	// Driver options: parse DATE/DATETIME values into time.Time, in UTC.
	params.Set("parseTime", "true")
	params.Set("loc", "UTC")
	return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s",
		c.User, c.Password, c.Host, c.Port, c.Name, params.Encode())
}
// ctxTxKey is the context key for the active transaction.
//
// unitOfWork.Do stores the Tx under this key and mysqlComponent.GetExecutor
// looks it up. Because the type is unexported, code outside this package
// cannot construct the key, so it cannot collide with other context values.
type ctxTxKey struct{}
// --- mysqlComponent ---

// mysqlComponent is the Component implementation returned by New. It holds
// a sync.RWMutex, so it must not be copied; New returns it by pointer and
// all methods use pointer receivers.
type mysqlComponent struct {
	logger logz.Logger
	// cfg is set by New and only read afterwards by code in this file.
	cfg Config
	// db is the connection pool: nil until OnInit succeeds and reset to nil
	// by OnStop.
	db *sql.DB
	// mu guards db: write-locked when db is assigned or cleared, read-locked
	// when a method takes a snapshot of db to use.
	mu sync.RWMutex
}
// New builds a mysql Component from cfg. It opens nothing itself: the pool
// is created in OnInit and checked in OnStart, so register the result with
// the launcher (lc.Append(db)) to have its lifecycle managed.
func New(logger logz.Logger, cfg Config) Component {
	comp := &mysqlComponent{cfg: cfg}
	comp.logger = logger
	return comp
}
// OnInit opens the connection pool described by c.cfg and applies the pool
// settings:
//   - MaxConns -> SetMaxOpenConns
//   - MinConns -> SetMaxIdleConns (database/sql has no minimum pool size, so
//     this caps idle connections rather than guaranteeing a floor)
//   - MaxConnLifetime / MaxConnIdleTime -> SetConnMaxLifetime /
//     SetConnMaxIdleTime, each applied only when the setting is non-empty
//
// The duration strings are validated before sql.Open, so an invalid value
// returns an error without leaving an opened, unclosed *sql.DB behind.
// OnInit does not contact the server; OnStart pings it.
func (c *mysqlComponent) OnInit() error {
	var lifetime, idleTime time.Duration
	if c.cfg.MaxConnLifetime != "" {
		d, err := time.ParseDuration(c.cfg.MaxConnLifetime)
		if err != nil {
			return fmt.Errorf("MYSQL_MAX_CONN_LIFETIME: %w", err)
		}
		lifetime = d
	}
	if c.cfg.MaxConnIdleTime != "" {
		d, err := time.ParseDuration(c.cfg.MaxConnIdleTime)
		if err != nil {
			return fmt.Errorf("MYSQL_MAX_CONN_IDLE_TIME: %w", err)
		}
		idleTime = d
	}

	db, err := sql.Open("mysql", c.cfg.DSN())
	if err != nil {
		return fmt.Errorf("mysql: open: %w", err)
	}
	db.SetMaxOpenConns(c.cfg.MaxConns)
	db.SetMaxIdleConns(c.cfg.MinConns)
	if c.cfg.MaxConnLifetime != "" {
		db.SetConnMaxLifetime(lifetime)
	}
	if c.cfg.MaxConnIdleTime != "" {
		db.SetConnMaxIdleTime(idleTime)
	}

	c.mu.Lock()
	c.db = db
	c.mu.Unlock()
	return nil
}
// OnStart verifies connectivity by pinging the server, giving up after five
// seconds, and logs once the ping succeeds.
func (c *mysqlComponent) OnStart() error {
	const startupPingTimeout = 5 * time.Second

	pingCtx, cancel := context.WithTimeout(context.Background(), startupPingTimeout)
	defer cancel()

	err := c.Ping(pingCtx)
	if err != nil {
		return fmt.Errorf("mysql: ping failed: %w", err)
	}
	c.logger.Info("mysql: connected")
	return nil
}
// OnStop closes the connection pool, if one is open, and clears it so later
// calls see an uninitialized component. Shutdown is best-effort: a Close
// failure is logged rather than returned, and OnStop always returns nil.
func (c *mysqlComponent) OnStop() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.db == nil {
		return nil
	}
	c.logger.Info("mysql: closing pool")
	if err := c.db.Close(); err != nil {
		// Previously discarded silently; surface it without failing shutdown.
		c.logger.Error("mysql: close failed", err)
	}
	c.db = nil
	return nil
}
// Ping checks that the pool can reach the server, honoring ctx for
// cancellation and deadlines. It reports "mysql: not initialized" when
// there is no pool (before OnInit or after OnStop).
func (c *mysqlComponent) Ping(ctx context.Context) error {
	c.mu.RLock()
	pool := c.db
	c.mu.RUnlock()

	if pool != nil {
		return pool.PingContext(ctx)
	}
	return fmt.Errorf("mysql: not initialized")
}
// GetExecutor returns the Executor that work for ctx should run on.
//
// If ctx carries a transaction (stored by unitOfWork.Do under ctxTxKey),
// that transaction is returned so the caller joins it. Otherwise the
// connection pool is returned. When there is no pool (before OnInit or after
// OnStop), the result is an untyped nil Executor, not a nil *sql.DB wrapped
// in a non-nil interface, so callers can detect it with `== nil`.
func (c *mysqlComponent) GetExecutor(ctx context.Context) Executor {
	if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok {
		return tx
	}
	c.mu.RLock()
	db := c.db
	c.mu.RUnlock()
	if db == nil {
		return nil
	}
	return db
}
// Begin starts a transaction on the pool with default options (nil
// *sql.TxOptions) and wraps it as a Tx. It fails with "mysql: not
// initialized" when there is no pool. Errors from BeginTx are returned
// unwrapped.
func (c *mysqlComponent) Begin(ctx context.Context) (Tx, error) {
	c.mu.RLock()
	pool := c.db
	c.mu.RUnlock()

	if pool == nil {
		return nil, fmt.Errorf("mysql: not initialized")
	}

	sqlTx, beginErr := pool.BeginTx(ctx, nil)
	if beginErr != nil {
		return nil, beginErr
	}
	return &mysqlTx{Tx: sqlTx}, nil
}
// ExecContext runs a statement that returns no rows directly on the pool.
// Like Ping and Begin, it returns "mysql: not initialized" when there is no
// pool instead of calling into a nil *sql.DB.
func (c *mysqlComponent) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	c.mu.RLock()
	db := c.db
	c.mu.RUnlock()
	if db == nil {
		return nil, fmt.Errorf("mysql: not initialized")
	}
	return db.ExecContext(ctx, query, args...)
}

// QueryContext runs a query directly on the pool. Like Ping and Begin, it
// returns "mysql: not initialized" when there is no pool instead of calling
// into a nil *sql.DB.
func (c *mysqlComponent) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	c.mu.RLock()
	db := c.db
	c.mu.RUnlock()
	if db == nil {
		return nil, fmt.Errorf("mysql: not initialized")
	}
	return db.QueryContext(ctx, query, args...)
}

// QueryRowContext runs a single-row query directly on the pool.
//
// NOTE(review): unlike ExecContext/QueryContext this cannot report a missing
// pool as an error, because an errored *sql.Row cannot be constructed outside
// database/sql. Calling it before OnInit or after OnStop calls into a nil
// *sql.DB.
func (c *mysqlComponent) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	c.mu.RLock()
	db := c.db
	c.mu.RUnlock()
	return db.QueryRowContext(ctx, query, args...)
}
// HandleError satisfies Client by delegating to the package-level
// HandleError function.
func (c *mysqlComponent) HandleError(err error) error {
	return HandleError(err)
}

// --- health.Checkable ---

// HealthCheck reports the database healthy when a Ping succeeds.
func (c *mysqlComponent) HealthCheck(ctx context.Context) error {
	return c.Ping(ctx)
}

// Name identifies this component in health reports.
func (c *mysqlComponent) Name() string {
	return "mysql"
}

// Priority marks the database as a critical dependency.
func (c *mysqlComponent) Priority() health.Level {
	return health.LevelCritical
}
// --- mysqlTx ---

// mysqlTx adapts *sql.Tx to the Tx interface. Embedding *sql.Tx promotes its
// ExecContext, QueryContext, QueryRowContext, Commit and Rollback methods,
// which already have exactly the signatures Tx requires, so no hand-written
// delegating wrappers are needed.
type mysqlTx struct{ *sql.Tx }
// --- UnitOfWork ---

// unitOfWork is the UnitOfWork implementation returned by NewUnitOfWork.
type unitOfWork struct {
	// logger receives rollback failures, which Do logs instead of returning.
	logger logz.Logger
	// client begins the transaction that each Do call runs inside.
	client Client
}
// NewUnitOfWork returns a UnitOfWork that opens its transactions through
// client and reports rollback failures to logger.
func NewUnitOfWork(logger logz.Logger, client Client) UnitOfWork {
	uow := new(unitOfWork)
	uow.logger, uow.client = logger, client
	return uow
}
// Do runs fn inside a single transaction.
//
// fn receives a context carrying the transaction under ctxTxKey, so
// Client.GetExecutor(ctx) inside fn resolves to that transaction. Outcomes:
//   - Begin fails: the error is returned wrapped as "mysql: begin transaction".
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned unchanged; a rollback failure is logged, not returned.
//   - fn panics: the transaction is rolled back (a failure is logged) and the
//     panic is re-raised, so the transaction and its connection are not left
//     open.
//   - fn succeeds: the result of Commit is returned.
func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error {
	tx, err := u.client.Begin(ctx)
	if err != nil {
		return fmt.Errorf("mysql: begin transaction: %w", err)
	}
	defer func() {
		if p := recover(); p != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				u.logger.Error("mysql: rollback failed", rbErr)
			}
			panic(p)
		}
	}()

	txCtx := context.WithValue(ctx, ctxTxKey{}, tx)
	if err := fn(txCtx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			u.logger.Error("mysql: rollback failed", rbErr)
		}
		return err
	}
	return tx.Commit()
}