diff --git a/database.go b/database.go index b756e9d..cd3eb52 100644 --- a/database.go +++ b/database.go @@ -1,11 +1,8 @@ package dbschema import ( - // External - "github.com/jackc/pgx/v5/pgxpool" - // Standard - "context" + "database/sql" "fmt" ) @@ -16,13 +13,9 @@ func newDatabase(host string, port int, dbName, user, pass string) (dbase Databa dbase.Username = user dbase.Password = pass - dbase.db, err = pgxpool.New(context.Background(), dbase.sqlConnString()) + dbase.db, err = sql.Open("postgres", dbase.sqlConnString()) return }// }}} -func databaseFromInstance(db *pgxpool.Pool) (dbase Database, err error) { - dbase.db = db - return -} func (dbase Database) sqlConnString() string {// {{{ return fmt.Sprintf( diff --git a/schema.go b/schema.go index 8ee3646..c7b5894 100644 --- a/schema.go +++ b/schema.go @@ -21,13 +21,15 @@ package dbschema import ( // External - "github.com/jackc/pgx/v5/pgxpool" + _ "github.com/lib/pq" + + // Standard + "database/sql" ) // An upgrader verifies the schema for one or more databases and upgrades them if possible. type Upgrader struct { - schema string - databases map[string]Database + databases map[string]Database logCallback func(string, string) sqlCallback func(string, int) ([]byte, bool) } @@ -39,7 +41,7 @@ type Database struct { Username string Password string - db *pgxpool.Pool + db *sql.DB upgrader *Upgrader } diff --git a/upgrader.go b/upgrader.go index c9670ab..7cb65f2 100644 --- a/upgrader.go +++ b/upgrader.go @@ -2,67 +2,57 @@ package dbschema import ( // External - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgxpool" + "github.com/lib/pq" // Standard - "context" + "database/sql" "fmt" ) -func defaultCallback(topic, msg string) { // {{{ +func defaultCallback(topic, msg string) {// {{{ fmt.Printf("[%s] %s\n", topic, msg) -} // }}} +}// }}} // NewUpgrader creates an upgrader with an empty list of databases. -func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{ - // Using a variadic function for backward compatibility. - if len(schema) > 0 { - upgrader.schema = schema[0] - } else { - upgrader.schema = "_db" - } - +func NewUpgrader() (upgrader Upgrader) {// {{{ upgrader.logCallback = defaultCallback upgrader.databases = map[string]Database{} return -} // }}} +}// }}} // SetLogCallback allows to set a callback for custom logging. -func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{ +func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{ upgrader.logCallback = callback -} // }}} +}// }}} // SetSqlCallback is required for providing the SQL schema updates. -func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{ +func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) {// {{{ upgrader.sqlCallback = callback -} // }}} +}// }}} // Version returns the current dbschema version for the given database name. -func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{ +func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{{ dbase, found := upgrader.databases[dbName] if !found { err = fmt.Errorf("Database %s not previously added to the upgrader", dbName) return } - version, err = dbase.Version() + version, err = dbase.version() return -} // }}} +}// }}} -func (dbase Database) createSchemaTable() (err error) { // {{{ - dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema)) - _, err = dbase.db.Exec(context.Background(), `CREATE SCHEMA "` + dbase.upgrader.schema + `"`) +func (dbase Database) createSchemaTable() (err error) {// {{{ + dbase.upgrader.logCallback("create", fmt.Sprintf("%s, _db.schema", dbase.DbName)) + _, err = dbase.db.Exec(`CREATE SCHEMA "_db"`) // Error code 42P06 "duplicate_schema" is an OK error, // table can still be missing and created. - pqErr, _ := err.(*pgconn.PgError) + pqErr, _ := err.(*pq.Error) if pqErr != nil && pqErr.Code != "42P06" { return } - _, err = dbase.db.Exec( - context.Background(), - `CREATE TABLE "` + dbase.upgrader.schema + `"."schema" ( + _, err = dbase.db.Exec(` + CREATE TABLE "_db"."schema" ( version int4 NOT NULL, updated timestamp NOT NULL DEFAULT NOW(), @@ -70,20 +60,19 @@ func (dbase Database) createSchemaTable() (err error) { // {{{ )`, ) return -} // }}} -func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{ - _, err = dbase.db.Exec(context.Background(), `INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version) +}// }}} +func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{ + _, err = dbase.db.Exec(`INSERT INTO _db.schema(version) VALUES($1)`, version) return -} // }}} +}// }}} -func (dbase Database) verifySchemaTable() (err error) { // {{{ - var rows pgx.Rows +func (dbase Database) verifySchemaTable() (err error) {// {{{ + var rows *sql.Rows if rows, err = dbase.db.Query( - context.Background(), `SELECT EXISTS ( SELECT FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = '` + dbase.upgrader.schema + `' + WHERE n.nspname = '_db' AND c.relname = 'schema' )`, ); err != nil { @@ -101,25 +90,24 @@ func (dbase Database) verifySchemaTable() (err error) { // {{{ } err = dbase.createSchemaTable() return -} // }}} -func (dbase Database) verifySchemaEntry() (err error) { // {{{ +}// }}} +func (dbase Database) verifySchemaEntry() (err error) {// {{{ var version int - var row pgx.Row - row = dbase.db.QueryRow(context.Background(), `SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`) + var row *sql.Row + row = dbase.db.QueryRow(`SELECT version FROM _db.schema LIMIT 1`) err = row.Scan(&version) - if err == pgx.ErrNoRows { + if err == sql.ErrNoRows { dbase.upgrader.logCallback("initiate version", dbase.DbName) err = dbase.appendSchemaVersion(0) } return -} // }}} -func (dbase Database) Version() (version int, err error) { // {{{ - var rows pgx.Rows +}// }}} +func (dbase Database) version() (version int, err error) {// {{{ + var rows *sql.Rows rows, err = dbase.db.Query( - context.Background(), - `SELECT version FROM ` + dbase.upgrader.schema + `.schema ORDER BY version DESC LIMIT 1`, + `SELECT version FROM _db.schema ORDER BY version DESC LIMIT 1`, ) if err != nil { return @@ -129,18 +117,19 @@ func (dbase Database) Version() (version int, err error) { // {{{ if rows.Next() { err = rows.Scan(&version) } else { - err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName) + err = fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbase.DbName) } return -} // }}} +}// }}} // AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table. -func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) { // {{{ +func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (err error) {// {{{ + var db Database if db, err = newDatabase(host, port, dbName, user, pass); err != nil { return } db.upgrader = &upgrader - + upgrader.databases[dbName] = db if err = db.verifySchemaTable(); err != nil { @@ -149,32 +138,17 @@ func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass s err = db.verifySchemaEntry() return -} // }}} -func (upgrader Upgrader) AddDatabaseInstance(sqlDB *pgxpool.Pool, dbName string) (db Database, err error) { // {{{ - db, err = databaseFromInstance(sqlDB) - - db.upgrader = &upgrader - - upgrader.databases[dbName] = db - - if err = db.verifySchemaTable(); err != nil { - return - } - - err = db.verifySchemaEntry() - return -} // }}} - +}// }}} // Run executes the actual schema updates until there are no more available. -func (upgrader Upgrader) Run() (err error) { // {{{ +func (upgrader Upgrader) Run() (err error) {// {{{ var version int for dbName, dbase := range upgrader.databases { - version, err = dbase.Version() + version, err = dbase.version() if err != nil { return } - upgrader.logCallback("version", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) + upgrader.logCallback("version", fmt.Sprintf("%s: %d", dbName, version)) for { version++ @@ -183,8 +157,8 @@ func (upgrader Upgrader) Run() (err error) { // {{{ break } - upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) - if _, err = dbase.db.Exec(context.Background(), string(sql)); err != nil { + upgrader.logCallback("exec", fmt.Sprintf("%s: %d", dbName, version)) + if _, err = dbase.db.Exec(string(sql)); err != nil { return } if err = dbase.appendSchemaVersion(version); err != nil { @@ -193,6 +167,6 @@ func (upgrader Upgrader) Run() (err error) { // {{{ } } return -} // }}} +}// }}} // vim: foldmethod=marker