package mysql

import (
	"context"
	"database/sql"
	"errors"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"code.nochebuena.dev/einherjar/contracts/lifecycle"
	"code.nochebuena.dev/einherjar/contracts/logging"
	"code.nochebuena.dev/einherjar/contracts/observability"
	"code.nochebuena.dev/einherjar/core/xerrors"
	mysqldrv "github.com/go-sql-driver/mysql"
)

// Compile-time interface checks (CT-5 / I-8).
var _ lifecycle.Component = (Component)(nil)
var _ observability.Checkable = (Component)(nil)
var _ Provider = (Component)(nil)

// --- CT-6: at most one exported TypeSpec per non-test, non-doc file ---

// TestAtMostOneExportedTypePerFile parses every non-test Go file in the
// package directory (doc.go excluded) and fails for any file that declares
// more than one exported type.
func TestAtMostOneExportedTypePerFile(t *testing.T) {
	fset := token.NewFileSet()
	pkgs, err := parser.ParseDir(fset, ".", func(fi os.FileInfo) bool {
		name := fi.Name()
		return !strings.HasSuffix(name, "_test.go") && name != "doc.go"
	}, 0)
	if err != nil {
		t.Fatalf("parse: %v", err)
	}
	for _, pkg := range pkgs {
		for path, file := range pkg.Files {
			base := filepath.Base(path)
			count := 0
			for _, decl := range file.Decls {
				gd, ok := decl.(*ast.GenDecl)
				if !ok {
					continue
				}
				for _, spec := range gd.Specs {
					ts, ok := spec.(*ast.TypeSpec)
					if ok && ts.Name.IsExported() {
						count++
					}
				}
			}
			if count > 1 {
				t.Errorf("%s: %d exported TypeSpecs (max 1)", base, count)
			}
		}
	}
}

// --- Config defaults (S-4) ---

// TestDefaultConfig_OptionalFields checks that DefaultConfig populates the
// optional Port, MaxConns and Charset fields.
func TestDefaultConfig_OptionalFields(t *testing.T) {
	cfg := DefaultConfig()
	if cfg.Port == 0 {
		t.Error("Port must have a default")
	}
	if cfg.MaxConns == 0 {
		t.Error("MaxConns must have a default")
	}
	if cfg.Charset == "" {
		t.Error("Charset must have a default")
	}
}

// --- Config.DSN ---

// TestConfig_DSN checks that the DSN embeds user, host, port and database name.
func TestConfig_DSN(t *testing.T) {
	cfg := Config{Host: "localhost", Port: 3306, User: "root", Password: "pass", Name: "mydb"}
	dsn := cfg.DSN()
	if dsn == "" {
		t.Fatal("DSN empty")
	}
	for _, want := range []string{"root", "localhost", "3306", "mydb"} {
		if !strings.Contains(dsn, want) {
			t.Errorf("DSN %q missing %q", dsn, want)
		}
	}
}

// TestConfig_DSN_Defaults checks that a Config with no Charset still yields
// a DSN containing utf8mb4.
func TestConfig_DSN_Defaults(t *testing.T) {
	cfg := Config{Host: "h", Port: 3306, User: "u", Password: "p", Name: "db"}
	dsn := cfg.DSN()
	if !strings.Contains(dsn, "utf8mb4") {
		t.Error("DSN missing default charset utf8mb4")
	}
}

// --- HandleError ---

// TestHandleError_Nil checks that HandleError passes nil through.
func TestHandleError_Nil(t *testing.T) {
	if err := HandleError(nil); err != nil {
		t.Errorf("want nil, got %v", err)
	}
}

// TestHandleError_DuplicateEntry maps MySQL error 1062 to ErrAlreadyExists.
func TestHandleError_DuplicateEntry(t *testing.T) {
	assertXCode(t, HandleError(&mysqldrv.MySQLError{Number: 1062}), xerrors.ErrAlreadyExists)
}

// TestHandleError_ForeignKey_1452 maps MySQL error 1452 to ErrInvalidInput.
func TestHandleError_ForeignKey_1452(t *testing.T) {
	assertXCode(t, HandleError(&mysqldrv.MySQLError{Number: 1452}), xerrors.ErrInvalidInput)
}

// TestHandleError_ForeignKey_1451 maps MySQL error 1451 to ErrInvalidInput.
func TestHandleError_ForeignKey_1451(t *testing.T) {
	assertXCode(t, HandleError(&mysqldrv.MySQLError{Number: 1451}), xerrors.ErrInvalidInput)
}

// TestHandleError_ForeignKey_1216 maps MySQL error 1216 to ErrInvalidInput.
func TestHandleError_ForeignKey_1216(t *testing.T) {
	assertXCode(t, HandleError(&mysqldrv.MySQLError{Number: 1216}), xerrors.ErrInvalidInput)
}

// TestHandleError_ForeignKey_1217 maps MySQL error 1217 to ErrInvalidInput.
func TestHandleError_ForeignKey_1217(t *testing.T) {
	assertXCode(t, HandleError(&mysqldrv.MySQLError{Number: 1217}), xerrors.ErrInvalidInput)
}

// TestHandleError_NoRows maps sql.ErrNoRows to ErrNotFound.
func TestHandleError_NoRows(t *testing.T) {
	assertXCode(t, HandleError(sql.ErrNoRows), xerrors.ErrNotFound)
}

// TestHandleError_Generic maps any unrecognised error to ErrInternal.
func TestHandleError_Generic(t *testing.T) {
	assertXCode(t, HandleError(errors.New("boom")), xerrors.ErrInternal)
}

// --- New ---

// TestNew_NotNil checks that New returns a non-nil component for a zero Config.
func TestNew_NotNil(t *testing.T) {
	if New(newLogger(), Config{}) == nil {
		t.Fatal("New returned nil")
	}
}

// --- Component metadata ---

// TestComponent_Name checks the component's registered name.
func TestComponent_Name(t *testing.T) {
	c := New(newLogger(), Config{})
	if c.Name() != "mysql" {
		t.Errorf("Name() = %q, want mysql", c.Name())
	}
}

// TestComponent_Priority_IsCritical checks the component reports critical priority.
func TestComponent_Priority_IsCritical(t *testing.T) {
	c := New(newLogger(), Config{})
	if c.Priority() != observability.LevelCritical {
		t.Error("Priority() must be LevelCritical")
	}
}

// --- Nil-DB safety ---

// TestComponent_OnStop_NilDB checks that stopping an unstarted component is a no-op.
func TestComponent_OnStop_NilDB(t *testing.T) {
	c := &mysqlImpl{logger: newLogger()}
	if err := c.OnStop(); err != nil {
		t.Errorf("OnStop with nil db: %v", err)
	}
}

// TestComponent_Begin_NilDB checks that Begin reports an error instead of
// panicking when no database handle is set.
func TestComponent_Begin_NilDB(t *testing.T) {
	c := &mysqlImpl{logger: newLogger()}
	_, err := c.Begin(context.Background())
	if err == nil {
		t.Error("expected error for nil db")
	}
}

// TestComponent_BeginTx_NilDB checks that BeginTx reports an error instead of
// panicking when no database handle is set.
func TestComponent_BeginTx_NilDB(t *testing.T) {
	c := &mysqlImpl{logger: newLogger()}
	_, err := c.BeginTx(context.Background(), nil)
	if err == nil {
		t.Error("expected error for nil db")
	}
}

// TestComponent_Stats_NilDB checks that Stats returns zero-valued counters
// when no database handle is set.
func TestComponent_Stats_NilDB(t *testing.T) {
	c := &mysqlImpl{logger: newLogger()}
	stats := c.Stats()
	if stats.MaxOpenConnections != 0 || stats.OpenConnections != 0 {
		t.Errorf("expected zero DBStats, got %+v", stats)
	}
}

// TestComponent_GetExecutor_ReturnsNil checks that GetExecutor yields nil when
// neither a transaction in ctx nor a database handle is available.
func TestComponent_GetExecutor_ReturnsNil(t *testing.T) {
	c := &mysqlImpl{logger: newLogger()}
	exec := c.GetExecutor(context.Background())
	if exec != nil {
		t.Error("GetExecutor should return nil when db is not initialized")
	}
}

// TestComponent_GetExecutor_ReturnsTx checks that an Executor stored in ctx
// under ctxTxKey takes precedence.
func TestComponent_GetExecutor_ReturnsTx(t *testing.T) {
	tx := &mockTx{}
	ctx := context.WithValue(context.Background(), ctxTxKey{}, Executor(tx))
	c := &mysqlImpl{logger: newLogger()}
	got := c.GetExecutor(ctx)
	if got != tx {
		t.Error("GetExecutor should return the injected Tx from context")
	}
}

// --- UnitOfWork ---

// TestUnitOfWork_Commit checks that a successful fn leads to Commit.
func TestUnitOfWork_Commit(t *testing.T) {
	tx := &mockTx{}
	uow := NewUnitOfWork(newLogger(), &mockProvider{tx: tx})
	if err := uow.Do(context.Background(), func(ctx context.Context) error { return nil }); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !tx.committed {
		t.Error("expected Commit")
	}
}

// TestUnitOfWork_Rollback checks that an error from fn:
//   - is reported to the caller of Do,
//   - triggers Rollback,
//   - does not trigger Commit.
func TestUnitOfWork_Rollback(t *testing.T) {
	tx := &mockTx{}
	uow := NewUnitOfWork(newLogger(), &mockProvider{tx: tx})
	err := uow.Do(context.Background(), func(ctx context.Context) error {
		return errors.New("fail")
	})
	// Only non-nil is asserted: Do may legitimately wrap or translate fn's error.
	if err == nil {
		t.Error("expected Do to return the error from fn")
	}
	if !tx.rolledBack {
		t.Error("expected Rollback")
	}
	if tx.committed {
		t.Error("Commit must not be called when fn fails")
	}
}

// TestUnitOfWork_InjectsExecutor checks that, inside fn, the provider's
// GetExecutor resolves to the transaction opened by Do.
func TestUnitOfWork_InjectsExecutor(t *testing.T) {
	tx := &mockTx{}
	p := &mockProvider{tx: tx}
	uow := NewUnitOfWork(newLogger(), p)
	var got Executor
	_ = uow.Do(context.Background(), func(ctx context.Context) error {
		got = p.GetExecutor(ctx)
		return nil
	})
	if got != tx {
		t.Error("GetExecutor should return the injected Tx")
	}
}

// TestUnitOfWork_ReturnsBeginError checks that a failure to open the
// transaction is returned from Do and that fn is never run without one.
func TestUnitOfWork_ReturnsBeginError(t *testing.T) {
	p := &mockProvider{beginErr: errors.New("connect refused")}
	uow := NewUnitOfWork(newLogger(), p)
	called := false
	err := uow.Do(context.Background(), func(ctx context.Context) error {
		called = true
		return nil
	})
	if err == nil {
		t.Fatal("expected error when Begin fails")
	}
	if called {
		t.Error("fn must not run when Begin fails")
	}
}

// --- helpers ---

// stubLogger is a no-op logging.Logger for tests.
type stubLogger struct{}

func newLogger() *stubLogger { return &stubLogger{} }

func (s *stubLogger) Debug(msg string, args ...any)                  {}
func (s *stubLogger) Info(msg string, args ...any)                   {}
func (s *stubLogger) Warn(msg string, args ...any)                   {}
func (s *stubLogger) Error(msg string, err error, args ...any)       {}
func (s *stubLogger) With(args ...any) logging.Logger                { return s }
func (s *stubLogger) WithContext(ctx context.Context) logging.Logger { return s }

// mockTx is a Tx whose statement methods return zero values and whose
// Commit/Rollback only record that they were called.
type mockTx struct {
	committed  bool
	rolledBack bool
}

func (m *mockTx) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) {
	return nil, nil
}

func (m *mockTx) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) {
	return nil, nil
}

func (m *mockTx) QueryRowContext(_ context.Context, _ string, _ ...any) *sql.Row {
	return nil
}

func (m *mockTx) Commit() error   { m.committed = true; return nil }
func (m *mockTx) Rollback() error { m.rolledBack = true; return nil }

// mockProvider hands out tx from Begin/BeginTx, or fails with beginErr when set.
type mockProvider struct {
	tx       *mockTx
	beginErr error
}

// GetExecutor returns the Executor stored in ctx under ctxTxKey, or nil.
// NOTE(review): assumed to be the same key UnitOfWork uses to inject the tx;
// TestUnitOfWork_InjectsExecutor depends on that.
func (m *mockProvider) GetExecutor(ctx context.Context) Executor {
	if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok {
		return tx
	}
	return nil
}

func (m *mockProvider) Begin(_ context.Context) (Tx, error) {
	if m.beginErr != nil {
		return nil, m.beginErr
	}
	return m.tx, nil
}

func (m *mockProvider) BeginTx(_ context.Context, _ *sql.TxOptions) (Tx, error) {
	if m.beginErr != nil {
		return nil, m.beginErr
	}
	return m.tx, nil
}

func (m *mockProvider) Ping(_ context.Context) error { return nil }

func (m *mockProvider) HandleError(err error) error { return HandleError(err) }

// assertXCode fails the test unless err is (or wraps) an *xerrors.Err whose
// Code equals want.
func assertXCode(t *testing.T, err error, want xerrors.Code) {
	t.Helper()
	var xe *xerrors.Err
	if !errors.As(err, &xe) {
		t.Fatalf("expected *xerrors.Err, got %T: %v", err, err)
	}
	if xe.Code() != want {
		t.Errorf("want code %v, got %v", want, xe.Code())
	}
}