diff --git a/database.go b/database.go index 18b3856..b756e9d 100644 --- a/database.go +++ b/database.go @@ -1,8 +1,11 @@ package dbschema import ( + // External + "github.com/jackc/pgx/v5/pgxpool" + // Standard - "database/sql" + "context" "fmt" ) @@ -13,17 +16,21 @@ func newDatabase(host string, port int, dbName, user, pass string) (dbase Databa dbase.Username = user dbase.Password = pass - dbase.db, err = sql.Open("postgres", dbase.sqlConnString()) + dbase.db, err = pgxpool.New(context.Background(), dbase.sqlConnString()) return }// }}} +func databaseFromInstance(db *pgxpool.Pool) (dbase Database, err error) { + dbase.db = db + return +} func (dbase Database) sqlConnString() string {// {{{ return fmt.Sprintf( - "postgresql://%s:%s@%s:%d/%s?sslmode=disable", - dbase.Username, - dbase.Password, + "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", dbase.Host, dbase.Port, + dbase.Username, + dbase.Password, dbase.DbName, ) }// }}} diff --git a/schema.go b/schema.go index 836a836..8ee3646 100644 --- a/schema.go +++ b/schema.go @@ -1,15 +1,33 @@ +/* +Package dbschema is used to keep the SQL schema up to date. + + func sqlProvider(dbName string, version int) (sql []byte, found bool) { + // read an SQL file and return the contents + return + } + + upgrader := dbschema.NewUpgrader() + upgrader.SetSqlCallback(sqlProvider) + + if err = upgrader.AddDatabase("127.0.0.1", 5432, "foo", "postgres", "password"); err != nil { + panic(err) + } + + if err = upgrader.Run(); err != nil { + panic(err) + } +*/ package dbschema import ( // External - _ "github.com/lib/pq" - - // Standard - "database/sql" + "github.com/jackc/pgx/v5/pgxpool" ) +// An upgrader verifies the schema for one or more databases and upgrades them if possible. type Upgrader struct { - databases map[string]Database + schema string + databases map[string]Database logCallback func(string, string) sqlCallback func(string, int) ([]byte, bool) } @@ -21,7 +39,7 @@ type Database struct { Username string Password string - db *sql.DB + db *pgxpool.Pool upgrader *Upgrader } diff --git a/upgrader.go b/upgrader.go index 13e7c38..c9670ab 100644 --- a/upgrader.go +++ b/upgrader.go @@ -1,34 +1,89 @@ package dbschema import ( + // External + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + // Standard - "database/sql" + "context" "fmt" ) -func defaultCallback(topic, msg string) {// {{{ +func defaultCallback(topic, msg string) { // {{{ fmt.Printf("[%s] %s\n", topic, msg) -}// }}} -func NewUpgrader(host string, port int, dbName, user, pass string) (upgrader Upgrader, err error) {// {{{ +} // }}} + +// NewUpgrader creates an upgrader with an empty list of databases. +func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{ + // Using a variadic function for backward compatibility. + if len(schema) > 0 { + upgrader.schema = schema[0] + } else { + upgrader.schema = "_db" + } + upgrader.logCallback = defaultCallback upgrader.databases = map[string]Database{} return -}// }}} +} // }}} -func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{ +// SetLogCallback allows to set a callback for custom logging. +func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{ upgrader.logCallback = callback -}// }}} -func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) {// {{{ +} // }}} +// SetSqlCallback is required for providing the SQL schema updates. +func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{ upgrader.sqlCallback = callback -}// }}} +} // }}} +// Version returns the current dbschema version for the given database name. +func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{ + dbase, found := upgrader.databases[dbName] + if !found { + err = fmt.Errorf("Database %s not previously added to the upgrader", dbName) + return + } -func (dbase Database) verifySchemaTable() (err error) {// {{{ - var rows *sql.Rows + version, err = dbase.Version() + return +} // }}} + +func (dbase Database) createSchemaTable() (err error) { // {{{ + dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema)) + _, err = dbase.db.Exec(context.Background(), `CREATE SCHEMA "` + dbase.upgrader.schema + `"`) + + // Error code 42P06 "duplicate_schema" is an OK error, + // table can still be missing and created. + pqErr, _ := err.(*pgconn.PgError) + if pqErr != nil && pqErr.Code != "42P06" { + return + } + + _, err = dbase.db.Exec( + context.Background(), + `CREATE TABLE "` + dbase.upgrader.schema + `"."schema" ( + version int4 NOT NULL, + updated timestamp NOT NULL DEFAULT NOW(), + + CONSTRAINT schema_pk PRIMARY KEY (version) + )`, + ) + return +} // }}} +func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{ + _, err = dbase.db.Exec(context.Background(), `INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version) + return +} // }}} + +func (dbase Database) verifySchemaTable() (err error) { // {{{ + var rows pgx.Rows if rows, err = dbase.db.Query( + context.Background(), `SELECT EXISTS ( SELECT FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = '_db' + WHERE n.nspname = '` + dbase.upgrader.schema + `' AND c.relname = 'schema' )`, ); err != nil { @@ -41,45 +96,30 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{ return } - if !exists { - dbase.upgrader.logCallback("create", fmt.Sprintf("%s, _db.schema", dbase.DbName)) - dbase.db.Exec(`CREATE SCHEMA "_db"`) - - if _, err = dbase.db.Exec(` - CREATE TABLE "_db"."schema" ( - version int4 NOT NULL, - updated timestamp NOT NULL DEFAULT NOW(), - - CONSTRAINT schema_pk PRIMARY KEY (version) - )`, - ); err != nil { - return - } - } - return -}// }}} -func (dbase Database) verifySchemaEntry() (err error) {// {{{ - var rows *sql.Rows - rows, err = dbase.db.Query(`SELECT version FROM _db.schema LIMIT 1`) - if err != nil { + if exists { return } - defer rows.Close() + err = dbase.createSchemaTable() + return +} // }}} +func (dbase Database) verifySchemaEntry() (err error) { // {{{ + var version int + var row pgx.Row + row = dbase.db.QueryRow(context.Background(), `SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`) - if !rows.Next() { + err = row.Scan(&version) + if err == pgx.ErrNoRows { dbase.upgrader.logCallback("initiate version", dbase.DbName) - _, err = dbase.db.Exec(`INSERT INTO _db.schema(version) VALUES(0)`) - if err != nil { - return - } + err = dbase.appendSchemaVersion(0) } return -}// }}} -func (dbase Database) version() (version int, err error) {// {{{ - var rows *sql.Rows +} // }}} +func (dbase Database) Version() (version int, err error) { // {{{ + var rows pgx.Rows rows, err = dbase.db.Query( - `SELECT version FROM _db.schema ORDER BY version DESC LIMIT 1`, + context.Background(), + `SELECT version FROM ` + dbase.upgrader.schema + `.schema ORDER BY version DESC LIMIT 1`, ) if err != nil { return @@ -89,18 +129,18 @@ func (dbase Database) version() (version int, err error) {// {{{ if rows.Next() { err = rows.Scan(&version) } else { - err = fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbase.DbName) + err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName) } return -}// }}} +} // }}} -func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (err error) {// {{{ - var db Database +// AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table. +func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) { // {{{ if db, err = newDatabase(host, port, dbName, user, pass); err != nil { return } db.upgrader = &upgrader - + upgrader.databases[dbName] = db if err = db.verifySchemaTable(); err != nil { @@ -109,16 +149,32 @@ func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass s err = db.verifySchemaEntry() return -}// }}} -func (upgrader Upgrader) Run() (err error) {// {{{ +} // }}} +func (upgrader Upgrader) AddDatabaseInstance(sqlDB *pgxpool.Pool, dbName string) (db Database, err error) { // {{{ + db, err = databaseFromInstance(sqlDB) + + db.upgrader = &upgrader + + upgrader.databases[dbName] = db + + if err = db.verifySchemaTable(); err != nil { + return + } + + err = db.verifySchemaEntry() + return +} // }}} + +// Run executes the actual schema updates until there are no more available. +func (upgrader Upgrader) Run() (err error) { // {{{ var version int for dbName, dbase := range upgrader.databases { - version, err = dbase.version() + version, err = dbase.Version() if err != nil { return } - upgrader.logCallback("version", fmt.Sprintf("%s: %d", dbName, version)) + upgrader.logCallback("version", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) for { version++ @@ -127,20 +183,16 @@ func (upgrader Upgrader) Run() (err error) {// {{{ break } - upgrader.logCallback("exec", fmt.Sprintf("%s: %d", dbName, version)) - if _, err = dbase.db.Exec(string(sql)); err != nil { + upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) + if _, err = dbase.db.Exec(context.Background(), string(sql)); err != nil { return } - _, err = dbase.db.Exec(` - INSERT INTO _db.schema(version) - VALUES($1) - `, version) - if err != nil { + if err = dbase.appendSchemaVersion(version); err != nil { return } } } return -}// }}} +} // }}} // vim: foldmethod=marker