package postgres

import (
	"context"
	"errors"
	"strings"
	"testing"

	"github.com/jackc/pgerrcode"
	pgx "github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"

	"code.nochebuena.dev/go/health"
	"code.nochebuena.dev/go/logz"
	"code.nochebuena.dev/go/xerrors"
)

// newLogger returns a logger built from zero-value logz options, for tests.
func newLogger() logz.Logger { return logz.New(logz.Options{}) }

// --- New / name / priority ---

func TestNew(t *testing.T) {
	if New(newLogger(), Config{}) == nil {
		t.Fatal("New returned nil")
	}
}

func TestComponent_Name(t *testing.T) {
	// Comma-ok assertion so a missing interface fails the test instead of panicking.
	c, ok := New(newLogger(), Config{}).(health.Checkable)
	if !ok {
		t.Fatal("New did not return a health.Checkable")
	}
	if got := c.Name(); got != "postgres" {
		t.Errorf("Name() = %q, want %q", got, "postgres")
	}
}

func TestComponent_Priority(t *testing.T) {
	c, ok := New(newLogger(), Config{}).(health.Checkable)
	if !ok {
		t.Fatal("New did not return a health.Checkable")
	}
	if got := c.Priority(); got != health.LevelCritical {
		t.Errorf("Priority() = %v, want %v", got, health.LevelCritical)
	}
}

func TestComponent_OnStop_NilPool(t *testing.T) {
	// A component that was never started has no pool; OnStop must tolerate that.
	c := &pgComponent{logger: newLogger()}
	if err := c.OnStop(); err != nil {
		t.Errorf("OnStop with nil pool: %v", err)
	}
}

// --- Config.DSN ---

func TestConfig_DSN(t *testing.T) {
	cfg := Config{
		Host:     "localhost",
		Port:     5432,
		User:     "user",
		Password: "pass",
		Name:     "mydb",
		SSLMode:  "disable",
		Timezone: "UTC",
	}
	dsn := cfg.DSN()
	for _, want := range []string{"localhost:5432", "mydb", "sslmode=disable"} {
		if !strContains(dsn, want) {
			t.Errorf("DSN %q missing %q", dsn, want)
		}
	}
}

// --- HandleError ---

func TestHandleError_Nil(t *testing.T) {
	if err := HandleError(nil); err != nil {
		t.Errorf("want nil, got %v", err)
	}
}

func TestHandleError_UniqueViolation(t *testing.T) {
	assertCode(t, HandleError(&pgconn.PgError{Code: pgerrcode.UniqueViolation}), xerrors.ErrAlreadyExists)
}

func TestHandleError_ForeignKey(t *testing.T) {
	assertCode(t, HandleError(&pgconn.PgError{Code: pgerrcode.ForeignKeyViolation}), xerrors.ErrInvalidInput)
}

func TestHandleError_CheckViolation(t *testing.T) {
	assertCode(t, HandleError(&pgconn.PgError{Code: pgerrcode.CheckViolation}), xerrors.ErrInvalidInput)
}

func TestHandleError_NoRows(t *testing.T) {
	assertCode(t, HandleError(pgx.ErrNoRows), xerrors.ErrNotFound)
}

func TestHandleError_Generic(t *testing.T) {
	assertCode(t, HandleError(errors.New("boom")), xerrors.ErrInternal)
}

// --- UnitOfWork ---

// mockTx is a no-op Tx that records whether Commit and/or Rollback were called.
type mockTx struct{ committed, rolledBack bool }

func (m *mockTx) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
	return pgconn.CommandTag{}, nil
}

func (m *mockTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
	return nil, nil
}

func (m *mockTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return nil }

func (m *mockTx) Commit(ctx context.Context) error {
	m.committed = true
	return nil
}

func (m *mockTx) Rollback(ctx context.Context) error {
	m.rolledBack = true
	return nil
}

// mockClient always hands out the same mockTx from Begin. GetExecutor returns
// the Executor stored under ctxTxKey{} in ctx, or nil when none is present.
type mockClient struct{ tx *mockTx }

func (m *mockClient) Begin(ctx context.Context) (Tx, error) { return m.tx, nil }

func (m *mockClient) Ping(ctx context.Context) error { return nil }

func (m *mockClient) HandleError(err error) error { return HandleError(err) }

func (m *mockClient) GetExecutor(ctx context.Context) Executor {
	if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok {
		return tx
	}
	return nil
}

func TestUnitOfWork_Commit(t *testing.T) {
	tx := &mockTx{}
	uow := NewUnitOfWork(newLogger(), &mockClient{tx: tx})
	if err := uow.Do(context.Background(), func(ctx context.Context) error { return nil }); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !tx.committed {
		t.Error("expected Commit")
	}
}

func TestUnitOfWork_Rollback(t *testing.T) {
	tx := &mockTx{}
	uow := NewUnitOfWork(newLogger(), &mockClient{tx: tx})
	err := uow.Do(context.Background(), func(ctx context.Context) error { return errors.New("fail") })
	// A failing callback must surface an error and must never be committed.
	if err == nil {
		t.Error("expected Do to return an error when the callback fails")
	}
	if !tx.rolledBack {
		t.Error("expected Rollback")
	}
	if tx.committed {
		t.Error("Commit must not be called when the callback fails")
	}
}

func TestUnitOfWork_InjectsExecutor(t *testing.T) {
	tx := &mockTx{}
	client := &mockClient{tx: tx}
	uow := NewUnitOfWork(newLogger(), client)
	var got Executor
	err := uow.Do(context.Background(), func(ctx context.Context) error {
		got = client.GetExecutor(ctx)
		return nil
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != tx {
		t.Error("GetExecutor should return the injected Tx")
	}
}

// --- helpers ---

// assertCode fails the test unless err is (or wraps) an *xerrors.Err whose
// Code() equals want.
func assertCode(t *testing.T, err error, want xerrors.Code) {
	t.Helper()
	var xe *xerrors.Err
	if !errors.As(err, &xe) {
		t.Fatalf("expected *xerrors.Err, got %T: %v", err, err)
	}
	if xe.Code() != want {
		t.Errorf("want code %s, got %s", want, xe.Code())
	}
}

// strContains reports whether sub is a substring of s; an empty sub always
// matches. It delegates to strings.Contains.
func strContains(s, sub string) bool {
	return strings.Contains(s, sub)
}