diff --git a/database.go b/database.go index cd3eb52..b756e9d 100644 --- a/database.go +++ b/database.go @@ -1,8 +1,11 @@ package dbschema import ( + // External + "github.com/jackc/pgx/v5/pgxpool" + // Standard - "database/sql" + "context" "fmt" ) @@ -13,9 +16,13 @@ func newDatabase(host string, port int, dbName, user, pass string) (dbase Databa dbase.Username = user dbase.Password = pass - dbase.db, err = sql.Open("postgres", dbase.sqlConnString()) + dbase.db, err = pgxpool.New(context.Background(), dbase.sqlConnString()) return }// }}} +func databaseFromInstance(db *pgxpool.Pool) (dbase Database, err error) { + dbase.db = db + return +} func (dbase Database) sqlConnString() string {// {{{ return fmt.Sprintf( diff --git a/schema.go b/schema.go index c7b5894..8ee3646 100644 --- a/schema.go +++ b/schema.go @@ -21,15 +21,13 @@ package dbschema import ( // External - _ "github.com/lib/pq" - - // Standard - "database/sql" + "github.com/jackc/pgx/v5/pgxpool" ) // An upgrader verifies the schema for one or more databases and upgrades them if possible. type Upgrader struct { - databases map[string]Database + schema string + databases map[string]Database logCallback func(string, string) sqlCallback func(string, int) ([]byte, bool) } @@ -41,7 +39,7 @@ type Database struct { Username string Password string - db *sql.DB + db *pgxpool.Pool upgrader *Upgrader } diff --git a/upgrader.go b/upgrader.go index 7cb65f2..c9670ab 100644 --- a/upgrader.go +++ b/upgrader.go @@ -2,57 +2,67 @@ package dbschema import ( // External - "github.com/lib/pq" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" // Standard - "database/sql" + "context" "fmt" ) -func defaultCallback(topic, msg string) {// {{{ +func defaultCallback(topic, msg string) { // {{{ fmt.Printf("[%s] %s\n", topic, msg) -}// }}} +} // }}} // NewUpgrader creates an upgrader with an empty list of databases. -func NewUpgrader() (upgrader Upgrader) {// {{{ +func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{ + // Using a variadic function for backward compatibility. + if len(schema) > 0 { + upgrader.schema = schema[0] + } else { + upgrader.schema = "_db" + } + upgrader.logCallback = defaultCallback upgrader.databases = map[string]Database{} return -}// }}} +} // }}} // SetLogCallback allows to set a callback for custom logging. -func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{ +func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{ upgrader.logCallback = callback -}// }}} +} // }}} // SetSqlCallback is required for providing the SQL schema updates. -func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) {// {{{ +func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{ upgrader.sqlCallback = callback -}// }}} +} // }}} // Version returns the current dbschema version for the given database name. -func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{{ +func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{ dbase, found := upgrader.databases[dbName] if !found { err = fmt.Errorf("Database %s not previously added to the upgrader", dbName) return } - version, err = dbase.version() + version, err = dbase.Version() return -}// }}} +} // }}} -func (dbase Database) createSchemaTable() (err error) {// {{{ - dbase.upgrader.logCallback("create", fmt.Sprintf("%s, _db.schema", dbase.DbName)) - _, err = dbase.db.Exec(`CREATE SCHEMA "_db"`) +func (dbase Database) createSchemaTable() (err error) { // {{{ + dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema)) + _, err = dbase.db.Exec(context.Background(), `CREATE SCHEMA "` + dbase.upgrader.schema + `"`) // Error code 42P06 "duplicate_schema" is an OK error, // table can still be missing and created. - pqErr, _ := err.(*pq.Error) + pqErr, _ := err.(*pgconn.PgError) if pqErr != nil && pqErr.Code != "42P06" { return } - _, err = dbase.db.Exec(` - CREATE TABLE "_db"."schema" ( + _, err = dbase.db.Exec( + context.Background(), + `CREATE TABLE "` + dbase.upgrader.schema + `"."schema" ( version int4 NOT NULL, updated timestamp NOT NULL DEFAULT NOW(), @@ -60,19 +70,20 @@ func (dbase Database) createSchemaTable() (err error) {// {{{ )`, ) return -}// }}} -func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{ - _, err = dbase.db.Exec(`INSERT INTO _db.schema(version) VALUES($1)`, version) +} // }}} +func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{ + _, err = dbase.db.Exec(context.Background(), `INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version) return -}// }}} +} // }}} -func (dbase Database) verifySchemaTable() (err error) {// {{{ - var rows *sql.Rows +func (dbase Database) verifySchemaTable() (err error) { // {{{ + var rows pgx.Rows if rows, err = dbase.db.Query( + context.Background(), `SELECT EXISTS ( SELECT FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = '_db' + WHERE n.nspname = '` + dbase.upgrader.schema + `' AND c.relname = 'schema' )`, ); err != nil { @@ -90,24 +101,25 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{ } err = dbase.createSchemaTable() return -}// }}} -func (dbase Database) verifySchemaEntry() (err error) {// {{{ +} // }}} +func (dbase Database) verifySchemaEntry() (err error) { // {{{ var version int - var row *sql.Row - row = dbase.db.QueryRow(`SELECT version FROM _db.schema LIMIT 1`) + var row pgx.Row + row = dbase.db.QueryRow(context.Background(), `SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`) err = row.Scan(&version) - if err == sql.ErrNoRows { + if err == pgx.ErrNoRows { dbase.upgrader.logCallback("initiate version", dbase.DbName) err = dbase.appendSchemaVersion(0) } return -}// }}} -func (dbase Database) version() (version int, err error) {// {{{ - var rows *sql.Rows +} // }}} +func (dbase Database) Version() (version int, err error) { // {{{ + var rows pgx.Rows rows, err = dbase.db.Query( - `SELECT version FROM _db.schema ORDER BY version DESC LIMIT 1`, + context.Background(), + `SELECT version FROM ` + dbase.upgrader.schema + `.schema ORDER BY version DESC LIMIT 1`, ) if err != nil { return @@ -117,19 +129,18 @@ func (dbase Database) version() (version int, err error) {// {{{ if rows.Next() { err = rows.Scan(&version) } else { - err = fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbase.DbName) + err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName) } return -}// }}} +} // }}} // AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table. -func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (err error) {// {{{ - var db Database +func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) { // {{{ if db, err = newDatabase(host, port, dbName, user, pass); err != nil { return } db.upgrader = &upgrader - + upgrader.databases[dbName] = db if err = db.verifySchemaTable(); err != nil { @@ -138,17 +149,32 @@ func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass s err = db.verifySchemaEntry() return -}// }}} +} // }}} +func (upgrader Upgrader) AddDatabaseInstance(sqlDB *pgxpool.Pool, dbName string) (db Database, err error) { // {{{ + db, err = databaseFromInstance(sqlDB) + + db.upgrader = &upgrader + + upgrader.databases[dbName] = db + + if err = db.verifySchemaTable(); err != nil { + return + } + + err = db.verifySchemaEntry() + return +} // }}} + // Run executes the actual schema updates until there are no more available. -func (upgrader Upgrader) Run() (err error) {// {{{ +func (upgrader Upgrader) Run() (err error) { // {{{ var version int for dbName, dbase := range upgrader.databases { - version, err = dbase.version() + version, err = dbase.Version() if err != nil { return } - upgrader.logCallback("version", fmt.Sprintf("%s: %d", dbName, version)) + upgrader.logCallback("version", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) for { version++ @@ -157,8 +183,8 @@ func (upgrader Upgrader) Run() (err error) {// {{{ break } - upgrader.logCallback("exec", fmt.Sprintf("%s: %d", dbName, version)) - if _, err = dbase.db.Exec(string(sql)); err != nil { + upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) + if _, err = dbase.db.Exec(context.Background(), string(sql)); err != nil { return } if err = dbase.appendSchemaVersion(version); err != nil { @@ -167,6 +193,6 @@ func (upgrader Upgrader) Run() (err error) {// {{{ } } return -}// }}} +} // }}} // vim: foldmethod=marker