package mysql

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"sync"
	"time"

	_ "github.com/go-sql-driver/mysql" // register driver

	"code.nochebuena.dev/go/health"
	"code.nochebuena.dev/go/launcher"
	"code.nochebuena.dev/go/logz"
)

type (
	// Config holds MySQL connection settings. The env/envDefault struct tags
	// name the MYSQL_* environment variable (and default, if any) for each
	// field. MinConns is applied as the pool's idle-connection cap, since
	// database/sql has no minimum-size setting. MaxConnLifetime and
	// MaxConnIdleTime are Go duration strings such as "30m"; an empty value
	// leaves the corresponding pool limit unset.
	Config struct {
		Host            string `env:"MYSQL_HOST,required"`
		Port            int    `env:"MYSQL_PORT" envDefault:"3306"`
		User            string `env:"MYSQL_USER,required"`
		Password        string `env:"MYSQL_PASSWORD,required"`
		Name            string `env:"MYSQL_NAME,required"`
		MaxConns        int    `env:"MYSQL_MAX_CONNS" envDefault:"5"`
		MinConns        int    `env:"MYSQL_MIN_CONNS" envDefault:"2"`
		MaxConnLifetime string `env:"MYSQL_MAX_CONN_LIFETIME" envDefault:"1h"`
		MaxConnIdleTime string `env:"MYSQL_MAX_CONN_IDLE_TIME" envDefault:"30m"`
	}

	// Executor is the query surface shared by the connection pool and by an
	// open transaction, so data-access code can run against either one.
	Executor interface {
		ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
		QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
		QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	}

	// Tx is an Executor that must be finished with Commit or Rollback.
	// Neither method takes a context because *sql.Tx's do not.
	Tx interface {
		Executor
		Commit() error
		Rollback() error
	}

	// Client is the database access interface: it hands out executors,
	// starts transactions, checks connectivity, and post-processes errors
	// returned by database calls.
	Client interface {
		GetExecutor(ctx context.Context) Executor
		Begin(ctx context.Context) (Tx, error)
		Ping(ctx context.Context) error
		HandleError(err error) error
	}

	// UnitOfWork runs a group of operations inside one database transaction.
	UnitOfWork interface {
		Do(ctx context.Context, fn func(ctx context.Context) error) error
	}

	// Component is the full MySQL module: a launcher lifecycle component, a
	// health-checkable dependency, and a database Client in one value.
	Component interface {
		launcher.Component
		health.Checkable
		Client
	}
)

// DSN builds the go-sql-driver connection string for this configuration.
func (c Config) DSN() string { u := &url.URL{ Scheme: "mysql", User: url.UserPassword(c.User, c.Password), Host: fmt.Sprintf("%s:%d", c.Host, c.Port), Path: "/" + c.Name, } q := u.Query() q.Set("parseTime", "true") q.Set("loc", "UTC") u.RawQuery = q.Encode() // go-sql-driver uses user:pass@tcp(host:port)/db?params return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", c.User, c.Password, c.Host, c.Port, c.Name, q.Encode()) } // ctxTxKey is the context key for the active transaction. type ctxTxKey struct{} // --- mysqlComponent --- type mysqlComponent struct { logger logz.Logger cfg Config db *sql.DB mu sync.RWMutex } // New returns a mysql Component. Call lc.Append(db) to manage its lifecycle. func New(logger logz.Logger, cfg Config) Component { return &mysqlComponent{logger: logger, cfg: cfg} } func (c *mysqlComponent) OnInit() error { db, err := sql.Open("mysql", c.cfg.DSN()) if err != nil { return fmt.Errorf("mysql: open: %w", err) } db.SetMaxOpenConns(c.cfg.MaxConns) db.SetMaxIdleConns(c.cfg.MinConns) if c.cfg.MaxConnLifetime != "" { d, err := time.ParseDuration(c.cfg.MaxConnLifetime) if err != nil { return fmt.Errorf("MYSQL_MAX_CONN_LIFETIME: %w", err) } db.SetConnMaxLifetime(d) } if c.cfg.MaxConnIdleTime != "" { d, err := time.ParseDuration(c.cfg.MaxConnIdleTime) if err != nil { return fmt.Errorf("MYSQL_MAX_CONN_IDLE_TIME: %w", err) } db.SetConnMaxIdleTime(d) } c.mu.Lock() c.db = db c.mu.Unlock() return nil } func (c *mysqlComponent) OnStart() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := c.Ping(ctx); err != nil { return fmt.Errorf("mysql: ping failed: %w", err) } c.logger.Info("mysql: connected") return nil } func (c *mysqlComponent) OnStop() error { c.mu.Lock() defer c.mu.Unlock() if c.db != nil { c.logger.Info("mysql: closing pool") _ = c.db.Close() c.db = nil } return nil } func (c *mysqlComponent) Ping(ctx context.Context) error { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return 
fmt.Errorf("mysql: not initialized") } return db.PingContext(ctx) } func (c *mysqlComponent) GetExecutor(ctx context.Context) Executor { if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok { return tx } c.mu.RLock() db := c.db c.mu.RUnlock() return db } func (c *mysqlComponent) Begin(ctx context.Context) (Tx, error) { c.mu.RLock() db := c.db c.mu.RUnlock() if db == nil { return nil, fmt.Errorf("mysql: not initialized") } tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err } return &mysqlTx{Tx: tx}, nil } func (c *mysqlComponent) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { c.mu.RLock() db := c.db c.mu.RUnlock() return db.ExecContext(ctx, query, args...) } func (c *mysqlComponent) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { c.mu.RLock() db := c.db c.mu.RUnlock() return db.QueryContext(ctx, query, args...) } func (c *mysqlComponent) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { c.mu.RLock() db := c.db c.mu.RUnlock() return db.QueryRowContext(ctx, query, args...) } func (c *mysqlComponent) HandleError(err error) error { return HandleError(err) } // health.Checkable func (c *mysqlComponent) HealthCheck(ctx context.Context) error { return c.Ping(ctx) } func (c *mysqlComponent) Name() string { return "mysql" } func (c *mysqlComponent) Priority() health.Level { return health.LevelCritical } // --- mysqlTx --- type mysqlTx struct{ *sql.Tx } func (t *mysqlTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { return t.Tx.ExecContext(ctx, query, args...) } func (t *mysqlTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return t.Tx.QueryContext(ctx, query, args...) } func (t *mysqlTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { return t.Tx.QueryRowContext(ctx, query, args...) 
} func (t *mysqlTx) Commit() error { return t.Tx.Commit() } func (t *mysqlTx) Rollback() error { return t.Tx.Rollback() } // --- UnitOfWork --- type unitOfWork struct { logger logz.Logger client Client } // NewUnitOfWork returns a UnitOfWork backed by the given client. func NewUnitOfWork(logger logz.Logger, client Client) UnitOfWork { return &unitOfWork{logger: logger, client: client} } func (u *unitOfWork) Do(ctx context.Context, fn func(ctx context.Context) error) error { tx, err := u.client.Begin(ctx) if err != nil { return fmt.Errorf("mysql: begin transaction: %w", err) } ctx = context.WithValue(ctx, ctxTxKey{}, tx) if err := fn(ctx); err != nil { if rbErr := tx.Rollback(); rbErr != nil { u.logger.Error("mysql: rollback failed", rbErr) } return err } return tx.Commit() }