285 lines
7.6 KiB
Go
285 lines
7.6 KiB
Go
|
|
package postgres
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"go/ast"
|
||
|
|
"go/parser"
|
||
|
|
"go/token"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"github.com/jackc/pgerrcode"
|
||
|
|
pgx "github.com/jackc/pgx/v5"
|
||
|
|
"github.com/jackc/pgx/v5/pgconn"
|
||
|
|
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/lifecycle"
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/logging"
|
||
|
|
"code.nochebuena.dev/einherjar/contracts/observability"
|
||
|
|
"code.nochebuena.dev/einherjar/core/xerrors"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Compile-time interface checks (CT-5 / I-8).
|
||
|
|
var _ lifecycle.Component = (Component)(nil)
|
||
|
|
var _ observability.Checkable = (Component)(nil)
|
||
|
|
var _ Provider = (Component)(nil)
|
||
|
|
|
||
|
|
// --- CT-6: at most one exported TypeSpec per non-test, non-doc file ---
|
||
|
|
|
||
|
|
// TestAtMostOneExportedTypePerFile enforces CT-6: every non-test, non-doc Go
// source file in the package directory declares at most one exported type.
//
// Files are visited in the name-sorted order returned by os.ReadDir, so
// violations are reported in the same order on every run (ranging over the
// maps returned by parser.ParseDir did not guarantee that).
func TestAtMostOneExportedTypePerFile(t *testing.T) {
	entries, err := os.ReadDir(".")
	if err != nil {
		t.Fatalf("read dir: %v", err)
	}
	fset := token.NewFileSet()
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || filepath.Ext(name) != ".go" ||
			strings.HasSuffix(name, "_test.go") || name == "doc.go" {
			continue
		}
		// Only top-level declarations are inspected, so identifier
		// resolution is unnecessary work.
		file, err := parser.ParseFile(fset, name, nil, parser.SkipObjectResolution)
		if err != nil {
			t.Fatalf("parse: %v", err)
		}
		if count := exportedTypeSpecCount(file); count > 1 {
			t.Errorf("%s: %d exported TypeSpecs (max 1)", name, count)
		}
	}
}

// exportedTypeSpecCount returns how many exported type specs appear in the
// top-level declarations of file.
func exportedTypeSpecCount(file *ast.File) int {
	count := 0
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range gd.Specs {
			if ts, ok := spec.(*ast.TypeSpec); ok && ts.Name.IsExported() {
				count++
			}
		}
	}
	return count
}
|
||
|
|
|
||
|
|
// --- Config defaults (S-4) ---
|
||
|
|
|
||
|
|
func TestDefaultConfig_OptionalFields(t *testing.T) {
|
||
|
|
cfg := DefaultConfig()
|
||
|
|
if cfg.Port == 0 {
|
||
|
|
t.Error("Port must have a default")
|
||
|
|
}
|
||
|
|
if cfg.SSLMode == "" {
|
||
|
|
t.Error("SSLMode must have a default")
|
||
|
|
}
|
||
|
|
if cfg.Timezone == "" {
|
||
|
|
t.Error("Timezone must have a default")
|
||
|
|
}
|
||
|
|
if cfg.MaxConns == 0 {
|
||
|
|
t.Error("MaxConns must have a default")
|
||
|
|
}
|
||
|
|
if cfg.MinConns == 0 {
|
||
|
|
t.Error("MinConns must have a default")
|
||
|
|
}
|
||
|
|
if cfg.MaxConnLifetime == "" {
|
||
|
|
t.Error("MaxConnLifetime must have a default")
|
||
|
|
}
|
||
|
|
if cfg.MaxConnIdleTime == "" {
|
||
|
|
t.Error("MaxConnIdleTime must have a default")
|
||
|
|
}
|
||
|
|
if cfg.HealthCheckPeriod == "" {
|
||
|
|
t.Error("HealthCheckPeriod must have a default")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- Config.DSN ---
|
||
|
|
|
||
|
|
func TestConfig_DSN(t *testing.T) {
|
||
|
|
cfg := Config{
|
||
|
|
Host: "localhost", Port: 5432,
|
||
|
|
User: "user", Password: "pass",
|
||
|
|
Name: "mydb", SSLMode: "disable", Timezone: "UTC",
|
||
|
|
}
|
||
|
|
dsn := cfg.DSN()
|
||
|
|
for _, want := range []string{"localhost:5432", "mydb", "sslmode=disable"} {
|
||
|
|
if !strings.Contains(dsn, want) {
|
||
|
|
t.Errorf("DSN %q missing %q", dsn, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- HandleError ---
|
||
|
|
|
||
|
|
func TestHandleError_Nil(t *testing.T) {
|
||
|
|
if err := HandleError(nil); err != nil {
|
||
|
|
t.Errorf("want nil, got %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleError_UniqueViolation(t *testing.T) {
|
||
|
|
assertXCode(t, HandleError(&pgconn.PgError{Code: pgerrcode.UniqueViolation}), xerrors.ErrAlreadyExists)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleError_ForeignKey(t *testing.T) {
|
||
|
|
assertXCode(t, HandleError(&pgconn.PgError{Code: pgerrcode.ForeignKeyViolation}), xerrors.ErrInvalidInput)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleError_CheckViolation(t *testing.T) {
|
||
|
|
assertXCode(t, HandleError(&pgconn.PgError{Code: pgerrcode.CheckViolation}), xerrors.ErrInvalidInput)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleError_NoRows(t *testing.T) {
|
||
|
|
assertXCode(t, HandleError(pgx.ErrNoRows), xerrors.ErrNotFound)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHandleError_Generic(t *testing.T) {
|
||
|
|
assertXCode(t, HandleError(errors.New("boom")), xerrors.ErrInternal)
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- New / name / priority / stats ---
|
||
|
|
|
||
|
|
func TestNew_NotNil(t *testing.T) {
|
||
|
|
if New(newLogger(), Config{}) == nil {
|
||
|
|
t.Fatal("New returned nil")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestComponent_Name(t *testing.T) {
|
||
|
|
c := New(newLogger(), Config{})
|
||
|
|
if c.Name() != "postgres" {
|
||
|
|
t.Errorf("Name() = %q, want %q", c.Name(), "postgres")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestComponent_Priority(t *testing.T) {
|
||
|
|
c := New(newLogger(), Config{})
|
||
|
|
if c.Priority() != observability.LevelCritical {
|
||
|
|
t.Error("Priority() != LevelCritical")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestComponent_Stats_BeforeInit(t *testing.T) {
|
||
|
|
c := New(newLogger(), Config{})
|
||
|
|
if c.Stats() == nil {
|
||
|
|
t.Error("Stats() must return non-nil before pool init")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestComponent_OnStop_NilPool(t *testing.T) {
|
||
|
|
c := &pgComponent{logger: newLogger()}
|
||
|
|
if err := c.OnStop(); err != nil {
|
||
|
|
t.Errorf("OnStop with nil pool: %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestComponent_BeginTx_NilPool(t *testing.T) {
|
||
|
|
c := &pgComponent{logger: newLogger()}
|
||
|
|
_, err := c.BeginTx(context.Background(), pgx.TxOptions{})
|
||
|
|
if err == nil {
|
||
|
|
t.Error("expected error for nil pool")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- UnitOfWork ---
|
||
|
|
|
||
|
|
func TestUnitOfWork_Commit(t *testing.T) {
|
||
|
|
tx := &mockTx{}
|
||
|
|
uow := NewUnitOfWork(newLogger(), &mockProvider{tx: tx})
|
||
|
|
if err := uow.Do(context.Background(), func(ctx context.Context) error { return nil }); err != nil {
|
||
|
|
t.Fatalf("unexpected error: %v", err)
|
||
|
|
}
|
||
|
|
if !tx.committed {
|
||
|
|
t.Error("expected Commit to be called")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestUnitOfWork_Rollback(t *testing.T) {
|
||
|
|
tx := &mockTx{}
|
||
|
|
uow := NewUnitOfWork(newLogger(), &mockProvider{tx: tx})
|
||
|
|
_ = uow.Do(context.Background(), func(ctx context.Context) error { return errors.New("fail") })
|
||
|
|
if !tx.rolledBack {
|
||
|
|
t.Error("expected Rollback to be called")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestUnitOfWork_InjectsExecutor(t *testing.T) {
|
||
|
|
tx := &mockTx{}
|
||
|
|
client := &mockProvider{tx: tx}
|
||
|
|
uow := NewUnitOfWork(newLogger(), client)
|
||
|
|
var got Executor
|
||
|
|
_ = uow.Do(context.Background(), func(ctx context.Context) error {
|
||
|
|
got = client.GetExecutor(ctx)
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
if got != tx {
|
||
|
|
t.Error("GetExecutor should return the injected Tx inside Do")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestUnitOfWork_ReturnsBeginError(t *testing.T) {
|
||
|
|
client := &mockProvider{beginErr: errors.New("connection lost")}
|
||
|
|
uow := NewUnitOfWork(newLogger(), client)
|
||
|
|
if err := uow.Do(context.Background(), func(ctx context.Context) error { return nil }); err == nil {
|
||
|
|
t.Error("expected error when Begin fails")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- helpers ---
|
||
|
|
|
||
|
|
func assertXCode(t *testing.T, err error, want xerrors.Code) {
|
||
|
|
t.Helper()
|
||
|
|
var xe *xerrors.Err
|
||
|
|
if !errors.As(err, &xe) {
|
||
|
|
t.Fatalf("expected *xerrors.Err, got %T: %v", err, err)
|
||
|
|
}
|
||
|
|
if xe.Code() != want {
|
||
|
|
t.Errorf("want code %s, got %s", want, xe.Code())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// --- stubs ---
|
||
|
|
|
||
|
|
type stubLogger struct{}
|
||
|
|
|
||
|
|
func newLogger() *stubLogger { return &stubLogger{} }
|
||
|
|
|
||
|
|
func (s *stubLogger) Debug(msg string, args ...any) {}
|
||
|
|
func (s *stubLogger) Info(msg string, args ...any) {}
|
||
|
|
func (s *stubLogger) Warn(msg string, args ...any) {}
|
||
|
|
func (s *stubLogger) Error(msg string, err error, args ...any) {}
|
||
|
|
func (s *stubLogger) With(args ...any) logging.Logger { return s }
|
||
|
|
func (s *stubLogger) WithContext(ctx context.Context) logging.Logger { return s }
|
||
|
|
|
||
|
|
type mockTx struct{ committed, rolledBack bool }
|
||
|
|
|
||
|
|
func (m *mockTx) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
||
|
|
return pgconn.CommandTag{}, nil
|
||
|
|
}
|
||
|
|
func (m *mockTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
func (m *mockTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { return nil }
|
||
|
|
func (m *mockTx) Commit(ctx context.Context) error { m.committed = true; return nil }
|
||
|
|
func (m *mockTx) Rollback(ctx context.Context) error { m.rolledBack = true; return nil }
|
||
|
|
|
||
|
|
type mockProvider struct {
|
||
|
|
tx *mockTx
|
||
|
|
beginErr error
|
||
|
|
}
|
||
|
|
|
||
|
|
func (m *mockProvider) Begin(ctx context.Context) (Tx, error) {
|
||
|
|
if m.beginErr != nil {
|
||
|
|
return nil, m.beginErr
|
||
|
|
}
|
||
|
|
return m.tx, nil
|
||
|
|
}
|
||
|
|
func (m *mockProvider) BeginTx(ctx context.Context, opts pgx.TxOptions) (Tx, error) {
|
||
|
|
return m.Begin(ctx)
|
||
|
|
}
|
||
|
|
func (m *mockProvider) Ping(ctx context.Context) error { return nil }
|
||
|
|
func (m *mockProvider) HandleError(err error) error { return HandleError(err) }
|
||
|
|
func (m *mockProvider) GetExecutor(ctx context.Context) Executor {
|
||
|
|
if tx, ok := ctx.Value(ctxTxKey{}).(Executor); ok {
|
||
|
|
return tx
|
||
|
|
}
|
||
|
|
return m.tx
|
||
|
|
}
|