package mysql

import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"testing"

	mysqldrv "github.com/go-sql-driver/mysql"

	"code.nochebuena.dev/go/health"
	"code.nochebuena.dev/go/logz"
	"code.nochebuena.dev/go/xerrors"
)

// newLogger returns a logger built from zero-value logz options for use in tests.
func newLogger() logz.Logger { return logz.New(logz.Options{}) }

// --- New / name / priority ---

func TestNew(t *testing.T) {
	if New(newLogger(), Config{}) == nil {
		t.Fatal("New returned nil")
	}
}

func TestComponent_Name(t *testing.T) {
	// Comma-ok assertion: a missing health.Checkable implementation is
	// reported as a test failure rather than a panic.
	c, ok := New(newLogger(), Config{}).(health.Checkable)
	if !ok {
		t.Fatal("New result does not implement health.Checkable")
	}
	if c.Name() != "mysql" {
		t.Errorf("want mysql, got %s", c.Name())
	}
}

func TestComponent_Priority(t *testing.T) {
	c, ok := New(newLogger(), Config{}).(health.Checkable)
	if !ok {
		t.Fatal("New result does not implement health.Checkable")
	}
	if c.Priority() != health.LevelCritical {
		t.Error("Priority() != LevelCritical")
	}
}

// TestComponent_OnStop_NilDB checks that stopping a component whose db field
// was never set does not return an error.
func TestComponent_OnStop_NilDB(t *testing.T) {
	c := &mysqlComponent{logger: newLogger()}
	if err := c.OnStop(); err != nil {
		t.Errorf("OnStop with nil db: %v", err)
	}
}

// --- Config.DSN ---

func TestConfig_DSN(t *testing.T) {
	cfg := Config{Host: "localhost", Port: 3306, User: "root", Password: "pass", Name: "mydb"}
	dsn := cfg.DSN()
	if dsn == "" {
		t.Fatal("DSN empty")
	}
	for _, want := range []string{"root", "localhost", "3306", "mydb"} {
		if !strContains(dsn, want) {
			t.Errorf("DSN %q missing %q", dsn, want)
		}
	}
}

// --- HandleError ---

func TestHandleError_Nil(t *testing.T) {
	if err := HandleError(nil); err != nil {
		t.Errorf("want nil, got %v", err)
	}
}

func TestHandleError_DuplicateEntry(t *testing.T) {
	// MySQL error 1062: duplicate entry for a unique key.
	assertCode(t, HandleError(&mysqldrv.MySQLError{Number: 1062}), xerrors.ErrAlreadyExists)
}

func TestHandleError_ForeignKey(t *testing.T) {
	// MySQL error 1452: foreign key constraint fails on insert/update.
	assertCode(t, HandleError(&mysqldrv.MySQLError{Number: 1452}), xerrors.ErrInvalidInput)
}

func TestHandleError_NoRows(t *testing.T) {
	assertCode(t, HandleError(sql.ErrNoRows), xerrors.ErrNotFound)
}

func TestHandleError_Generic(t *testing.T) {
	assertCode(t, HandleError(errors.New("boom")), xerrors.ErrInternal)
}

// --- UnitOfWork ---

// Compile-time checks that the mock stays in sync with the package interfaces.
var (
	_ Tx       = (*mockTx)(nil)
	_ Executor = (*mockTx)(nil)
)

// mockTx is a no-op transaction that records whether Commit or Rollback
// was called.
type mockTx struct{ committed, rolledBack bool }

func (m *mockTx) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) {
	return nil, nil
}

func (m *mockTx) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) {
	return nil, nil
}

func (m *mockTx) QueryRowContext(_ context.Context, _ string, _ ...any) *sql.Row { return nil }

func (m *mockTx) Commit() error { m.committed = true; return nil }

func (m *mockTx) Rollback() error { m.rolledBack = true; return nil }

// mockClient hands out a single pre-built mockTx from Begin.
type mockClient struct{ tx *mockTx }

func (m *mockClient) Begin(_ context.Context) (Tx, error) { return m.tx, nil }

func (m *mockClient) Ping(_ context.Context) error { return nil }

func (m *mockClient) HandleError(err error) error { return HandleError(err) }

// GetExecutor returns the Executor stored in ctx under ctxTxKey{}, or nil when
// none is present. NOTE(review): presumably mirrors the production lookup of
// the transaction injected by UnitOfWork — confirm against the real client.
func (m *mockClient) GetExecutor(ctx context.Context) Executor {
	if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok {
		return tx
	}
	return nil
}

func TestUnitOfWork_Commit(t *testing.T) {
	tx := &mockTx{}
	uow := NewUnitOfWork(newLogger(), &mockClient{tx: tx})
	if err := uow.Do(context.Background(), func(ctx context.Context) error { return nil }); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !tx.committed {
		t.Error("expected Commit")
	}
}

func TestUnitOfWork_Rollback(t *testing.T) {
	tx := &mockTx{}
	uow := NewUnitOfWork(newLogger(), &mockClient{tx: tx})
	err := uow.Do(context.Background(), func(ctx context.Context) error {
		return errors.New("fail")
	})
	if err == nil {
		t.Error("expected Do to return the callback error")
	}
	if !tx.rolledBack {
		t.Error("expected Rollback")
	}
	if tx.committed {
		t.Error("Commit must not be called when the callback fails")
	}
}

func TestUnitOfWork_InjectsExecutor(t *testing.T) {
	tx := &mockTx{}
	client := &mockClient{tx: tx}
	uow := NewUnitOfWork(newLogger(), client)
	var got Executor
	_ = uow.Do(context.Background(), func(ctx context.Context) error {
		got = client.GetExecutor(ctx)
		return nil
	})
	if got != tx {
		t.Error("GetExecutor should return the injected Tx")
	}
}

// --- helpers ---

// assertCode fails the test unless err unwraps to an *xerrors.Err whose
// Code() equals want.
func assertCode(t *testing.T, err error, want xerrors.Code) {
	t.Helper()
	var xe *xerrors.Err
	if !errors.As(err, &xe) {
		t.Fatalf("expected *xerrors.Err, got %T: %v", err, err)
	}
	if xe.Code() != want {
		t.Errorf("want %s, got %s", want, xe.Code())
	}
}

// strContains reports whether sub occurs within s. It is kept as a thin
// wrapper over strings.Contains in case other test files in the package use it.
func strContains(s, sub string) bool {
	return strings.Contains(s, sub)
}