diff --git a/db.go b/db.go new file mode 100644 index 0000000..c2db708 --- /dev/null +++ b/db.go @@ -0,0 +1,150 @@ +package main + +import ( + // External + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" + + // Standard + "errors" + "fmt" + "io/fs" + "regexp" + "strconv" +) + +var ( + dbConn string + db *sqlx.DB +) + +func dbInit() (err error) { // {{{ + dbConn = fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + config.Database.Host, + config.Database.Port, + config.Database.Username, + config.Database.Password, + config.Database.Name, + ) + + logger.Info("db", "op", "connect", "host", config.Database.Host, "port", config.Database.Port) + + if db, err = sqlx.Connect("postgres", dbConn); err != nil { + return + } + + if err = dbVerifyInternals(); err != nil { + return + } + + err = dbUpdate() + return +} // }}} +func dbVerifyInternals() (err error) { // {{{ + var rows *sqlx.Rows + if rows, err = db.Queryx( + `SELECT EXISTS ( + SELECT FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = '_internal' + AND c.relname = 'db' + )`, + ); err != nil { + return + } + defer rows.Close() + var exists bool + rows.Next() + if err = rows.Scan(&exists); err != nil { + return + } + + if !exists { + logger.Info("db", "op", "create_db", "db", "_internal.db") + if _, err = db.Exec(` + CREATE SCHEMA "_internal"; + + CREATE TABLE "_internal".db ( + "key" varchar NOT NULL, + value varchar NULL, + CONSTRAINT db_pk PRIMARY KEY (key) + ); + + INSERT INTO _internal.db("key", "value") + VALUES('schema', '0'); + `, + ); err != nil { + return + } + } + return +} // }}} +func dbUpdate() (err error) { // {{{ + /* Current schema revision is read from database. + * Used to iterate through the embedded SQL updates + * up to the db schema version currently compiled + * program is made for. */ + var rows *sqlx.Rows + var schemaStr string + var schema int + rows, err = db.Queryx(`SELECT value FROM _internal.db WHERE "key"='schema'`) + if err != nil { + return + } + defer rows.Close() + + if !rows.Next() { + return errors.New("Table _interval.db missing schema row") + } + + if err = rows.Scan(&schemaStr); err != nil { + return + } + + // Run updates + schema, err = strconv.Atoi(schemaStr) + if err != nil { + return err + } + sqlSchemaVersion := sqlSchema() + for i := (schema + 1); i <= sqlSchemaVersion; i++ { + logger.Info("db", "op", "upgrade_schema", "schema", i) + sql, _ := embedded.ReadFile( + fmt.Sprintf("sql/%04d.sql", i), + ) + _, err = db.Exec(string(sql)) + if err != nil { + return + } + _, err = db.Exec(`UPDATE _internal.db SET "value"=$1 WHERE "key"='schema'`, i) + if err != nil { + return + } + logger.Info("db", "op", "upgrade_schema", "schema", i, "result", "ok") + } + + return +} // }}} +func sqlSchema() (max int) { // {{{ + var num int + + files, _ := fs.ReadDir(embedded, "sql") + sqlFilename := regexp.MustCompile(`^([0-9]+)\.sql$`) + + for _, file := range files { + fname := sqlFilename.FindStringSubmatch(file.Name()) + if len(fname) != 2 { + continue + } + num, _ = strconv.Atoi(fname[1]) + } + + if num > max { + max = num + } + + return +} // }}} + +// vim: foldmethod=marker diff --git a/file.go b/file.go index 0f8b5bc..0e297ca 100644 --- a/file.go +++ b/file.go @@ -5,6 +5,7 @@ import ( "github.com/jmoiron/sqlx" // Standard + "fmt" "time" ) @@ -19,11 +20,11 @@ type File struct { Uploaded time.Time } -func AddFile(userID int, file *File) (err error) { // {{{ - file.UserID = userID +func (session Session) AddFile(file *File) (err error) { // {{{ + file.UserID = session.UserID var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(` + rows, err = db.Queryx(` INSERT INTO file(user_id, node_id, filename, size, mime, md5) VALUES($1, $2, $3, $4, $5, $6) RETURNING id @@ -42,11 +43,12 @@ func AddFile(userID int, file *File) (err error) { // {{{ rows.Next() err = rows.Scan(&file.ID) + fmt.Printf("%#v\n", file) return } // }}} -func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{ +func (session Session) Files(nodeID, fileID int) (files []File, err error) { // {{{ var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx( + rows, err = db.Queryx( `SELECT * FROM file WHERE @@ -56,7 +58,7 @@ func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{ WHEN 0 THEN true ELSE id = $3 END`, - userID, + session.UserID, nodeID, fileID, ) diff --git a/key.go b/key.go index aff0f3a..708fe39 100644 --- a/key.go +++ b/key.go @@ -3,6 +3,8 @@ package main import ( // External "github.com/jmoiron/sqlx" + + // Standard ) type Key struct { @@ -12,9 +14,9 @@ type Key struct { Key string } -func Keys(userID int) (keys []Key, err error) { // {{{ +func (session Session) Keys() (keys []Key, err error) {// {{{ var rows *sqlx.Rows - if rows, err = service.Db.Conn.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, userID); err != nil { + if rows, err = db.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, session.UserID); err != nil { return } defer rows.Close() @@ -29,14 +31,14 @@ func Keys(userID int) (keys []Key, err error) { // {{{ } return -} // }}} -func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) { // {{{ +}// }}} +func (session Session) KeyCreate(description, keyEncoded string) (key Key, err error) {// {{{ var row *sqlx.Rows - if row, err = service.Db.Conn.Queryx( + if row, err = db.Queryx( `INSERT INTO crypto_key(user_id, description, key) VALUES($1, $2, $3) RETURNING *`, - userID, + session.UserID, description, keyEncoded, ); err != nil { @@ -52,10 +54,10 @@ func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) } return -} // }}} -func KeyCounter() (counter int64, err error) { // {{{ +}// }}} +func (session Session) KeyCounter() (counter int64, err error) {// {{{ var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`) + rows, err = db.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`) if err != nil { return } @@ -64,6 +66,6 @@ func KeyCounter() (counter int64, err error) { // {{{ rows.Next() err = rows.Scan(&counter) return -} // }}} +}// }}} // vim: foldmethod=marker diff --git a/main.go b/main.go index d2c95d1..6c0e611 100644 --- a/main.go +++ b/main.go @@ -1,23 +1,20 @@ package main import ( - // External - "git.gibonuddevalla.se/go/webservice" - - // Internal - "git.gibonuddevalla.se/go/webservice/session" - // Standard "crypto/md5" "embed" "encoding/hex" "flag" "fmt" + "html/template" "io" "log/slog" "net/http" "os" + "path" "path/filepath" + "regexp" "strconv" "strings" ) @@ -28,10 +25,8 @@ var ( flagPort int flagVersion bool flagCreateUser bool - flagCheckLocal bool flagConfig string - service *webservice.Service connectionManager ConnectionManager static http.Handler config Config @@ -39,24 +34,11 @@ var ( VERSION string //go:embed version sql/* - embeddedSQL embed.FS - - //go:embed static - staticFS embed.FS + embedded embed.FS ) -func sqlProvider(dbname string, version int) (sql []byte, found bool) { - var err error - sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version)) - if err != nil { - return - } - found = true - return -} - func init() { // {{{ - version, _ := embeddedSQL.ReadFile("version") + version, _ := embedded.ReadFile("version") VERSION = strings.TrimSpace(string(version)) opt := slog.HandlerOptions{} @@ -66,7 +48,6 @@ func init() { // {{{ flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on") flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists") flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit") - flag.BoolVar(&flagCheckLocal, "checklocal", false, "Check for local static file before embedded") flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file") flag.Parse() } // }}} @@ -77,47 +58,56 @@ func main() { // {{{ fmt.Printf("%s\n", VERSION) os.Exit(0) } + logger.Info("application", "version", VERSION) - service, err = webservice.New(flagConfig, VERSION) + config, err = ConfigRead(flagConfig) if err != nil { - logger.Error("application", "error", err) + logger.Error("config", "error", err) os.Exit(1) } - service.SetDatabase(sqlProvider) - service.SetStaticDirectory("static", true) - service.SetStaticFS(staticFS, "static") - service.Register("/node/upload", true, true, nodeUpload) - service.Register("/node/tree", true, true, nodeTree) - service.Register("/node/retrieve", true, true, nodeRetrieve) - service.Register("/node/create", true, true, nodeCreate) - service.Register("/node/update", true, true, nodeUpdate) - service.Register("/node/rename", true, true, nodeRename) - service.Register("/node/delete", true, true, nodeDelete) - service.Register("/node/download", true, true, nodeDownload) - service.Register("/node/search", true, true, nodeSearch) - service.Register("/key/retrieve", true, true, keyRetrieve) - service.Register("/key/create", true, true, keyCreate) - service.Register("/key/counter", true, true, keyCounter) - service.Register("/ws", false, false, service.WebsocketHandler) - service.Register("/", false, false, service.StaticHandler) + if err = dbInit(); err != nil { + logger.Error("db", "error", err) + os.Exit(1) + } if flagCreateUser { - service.CreateUserPrompt() + err = createUser() + if err != nil { + logger.Error("db", "error", err) + os.Exit(1) + } os.Exit(0) } - err = service.Start() - if err != nil { - logger.Error("webserver", "error", err) - os.Exit(1) - } - connectionManager = NewConnectionManager() go connectionManager.BroadcastLoop() + static = http.FileServer(http.Dir(config.Application.Directories.Static)) http.HandleFunc("/css_updated", cssUpdateHandler) + http.HandleFunc("/session/create", sessionCreate) + http.HandleFunc("/session/retrieve", sessionRetrieve) + http.HandleFunc("/session/authenticate", sessionAuthenticate) + http.HandleFunc("/user/password", userPassword) + http.HandleFunc("/node/tree", nodeTree) + http.HandleFunc("/node/retrieve", nodeRetrieve) + http.HandleFunc("/node/create", nodeCreate) + http.HandleFunc("/node/update", nodeUpdate) + http.HandleFunc("/node/rename", nodeRename) + http.HandleFunc("/node/delete", nodeDelete) + http.HandleFunc("/node/upload", nodeUpload) + http.HandleFunc("/node/download", nodeDownload) + http.HandleFunc("/node/search", nodeSearch) + http.HandleFunc("/key/retrieve", keyRetrieve) + http.HandleFunc("/key/create", keyCreate) + http.HandleFunc("/key/counter", keyCounter) + http.HandleFunc("/ws", websocketHandler) + http.HandleFunc("/", staticHandler) + + listen := fmt.Sprintf("%s:%d", LISTEN_HOST, flagPort) + logger.Info("webserver", "listen", listen, "domains", config.Websocket.Domains) + http.ListenAndServe(listen, nil) } // }}} func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{ @@ -137,8 +127,108 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{ return } } // }}} +func staticHandler(w http.ResponseWriter, r *http.Request) { // {{{ + data := struct { + VERSION string + }{ + VERSION: VERSION, + } + + // URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version. + // To get rid of problems with cached content in browser on a new version release, + // while also not disabling cache altogether. + logger.Debug("webserver", "request", r.URL.Path) + if r.URL.Path == "/favicon.ico" { + static.ServeHTTP(w, r) + return + } + + rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$") + if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil { + r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2]) + static.ServeHTTP(w, r) + return + } + + // Everything else is run through the template system. + // For now to get VERSION into files to fix caching. + logger.Debug("webserver", "template", r.URL.Path) + tmpl, err := newTemplate(r.URL.Path) + if err != nil { + if os.IsNotExist(err) { + w.WriteHeader(404) + } + w.Write([]byte(err.Error())) + return + } + + if err = tmpl.Execute(w, data); err != nil { + w.Write([]byte(err.Error())) + } +} // }}} + +func sessionCreate(w http.ResponseWriter, r *http.Request) { // {{{ + logger.Info("webserver", "request", "/session/create") + session, err := CreateSession() + if err != nil { + responseError(w, err) + return + } + responseData(w, map[string]interface{}{ + "OK": true, + "Session": session, + }) +} // }}} +func sessionRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ + logger.Info("webserver", "request", "/session/retrieve") + var err error + var found bool + var session Session + + if session, found, err = ValidateSession(r, false); err != nil { + responseError(w, err) + return + } + + responseData(w, map[string]interface{}{ + "OK": true, + "Valid": found, + "Session": session, + }) +} // }}} +func sessionAuthenticate(w http.ResponseWriter, r *http.Request) { // {{{ + logger.Info("webserver", "request", "/session/authenticate") + var err error + var session Session + var authenticated bool + + // Validate session + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + + req := struct { + Username string + Password string + }{} + if err = parseRequest(r, &req); err != nil { + responseError(w, err) + return + } + + if authenticated, err = session.Authenticate(req.Username, req.Password); err != nil { + responseError(w, err) + return + } + + responseData(w, map[string]interface{}{ + "OK": true, + "Authenticated": authenticated, + "Session": session, + }) +} // }}} -/* func userPassword(w http.ResponseWriter, r *http.Request) { // {{{ var err error var ok bool @@ -169,11 +259,15 @@ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{ "CurrentPasswordOK": ok, }) } // }}} -*/ -func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ - logger.Info("webserver", "request", "/node/tree") +func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{ var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } req := struct{ StartNodeID int }{} if err = parseRequest(r, &req); err != nil { @@ -181,7 +275,7 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ return } - nodes, err := NodeTree(sess.UserID, req.StartNodeID) + nodes, err := session.NodeTree(req.StartNodeID) if err != nil { responseError(w, err) return @@ -192,9 +286,15 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ "Nodes": nodes, }) } // }}} -func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/retrieve") var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } req := struct{ ID int }{} if err = parseRequest(r, &req); err != nil { @@ -202,7 +302,7 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // return } - node, err := RetrieveNode(sess.UserID, req.ID) + node, err := session.Node(req.ID) if err != nil { responseError(w, err) return @@ -213,9 +313,15 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // "Node": node, }) } // }}} -func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/create") var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } req := struct { Name string @@ -226,7 +332,7 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ return } - node, err := CreateNode(sess.UserID, req.ParentID, req.Name) + node, err := session.CreateNode(req.ParentID, req.Name) if err != nil { responseError(w, err) return @@ -237,9 +343,15 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ "Node": node, }) } // }}} -func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/update") var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } req := struct { NodeID int @@ -251,7 +363,7 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ return } - err = UpdateNode(sess.UserID, req.NodeID, req.Content, req.CryptoKeyID) + err = session.UpdateNode(req.NodeID, req.Content, req.CryptoKeyID) if err != nil { responseError(w, err) return @@ -261,10 +373,17 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ "OK": true, }) } // }}} -func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ - var err error +func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/rename") + var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + req := struct { NodeID int Name string @@ -274,7 +393,7 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ return } - err = RenameNode(sess.UserID, req.NodeID, req.Name) + err = session.RenameNode(req.NodeID, req.Name) if err != nil { responseError(w, err) return @@ -284,10 +403,17 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ "OK": true, }) } // }}} -func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ - var err error +func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/delete") + var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + req := struct { NodeID int }{} @@ -296,7 +422,7 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ return } - err = DeleteNode(sess.UserID, req.NodeID) + err = session.DeleteNode(req.NodeID) if err != nil { responseError(w, err) return @@ -306,9 +432,15 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ "OK": true, }) } // }}} -func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/upload") var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } // Parse our multipart form, 10 << 20 specifies a maximum // upload of 10 MB files. @@ -347,7 +479,7 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ MIME: handler.Header.Get("Content-Type"), MD5: md5sum, } - if err = AddFile(sess.UserID, &nodeFile); err != nil { + if err = session.AddFile(&nodeFile); err != nil { responseError(w, err) return } @@ -384,11 +516,17 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ "File": nodeFile, }) } // }}} -func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/download") var err error + var session Session var files []File + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + req := struct { NodeID int FileID int @@ -398,7 +536,7 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // return } - files, err = Files(sess.UserID, req.NodeID, req.FileID) + files, err = session.Files(req.NodeID, req.FileID) if err != nil { responseError(w, err) return @@ -445,11 +583,17 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // } } // }}} -func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/files") var err error + var session Session var files []File + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + req := struct { NodeID int }{} @@ -458,7 +602,7 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ return } - files, err = Files(sess.UserID, req.NodeID, 0) + files, err = session.Files(req.NodeID, 0) if err != nil { responseError(w, err) return @@ -469,11 +613,17 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ "Files": files, }) } // }}} -func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/node/search") var err error + var session Session var nodes []Node + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + req := struct { Search string }{} @@ -482,7 +632,7 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ return } - nodes, err = SearchNodes(sess.UserID, req.Search) + nodes, err = session.SearchNodes(req.Search) if err != nil { responseError(w, err) return @@ -494,11 +644,17 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ }) } // }}} -func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/key/retrieve") var err error + var session Session - keys, err := Keys(sess.UserID) + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + + keys, err := session.Keys() if err != nil { responseError(w, err) return @@ -509,9 +665,15 @@ func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // { "Keys": keys, }) } // }}} -func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/key/create") var err error + var session Session + + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } req := struct { Description string @@ -522,7 +684,7 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ return } - key, err := KeyCreate(sess.UserID, req.Description, req.Key) + key, err := session.KeyCreate(req.Description, req.Key) if err != nil { responseError(w, err) return @@ -533,11 +695,17 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ "Key": key, }) } // }}} -func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{ logger.Info("webserver", "request", "/key/counter") var err error + var session Session - counter, err := KeyCounter() + if session, _, err = ValidateSession(r, true); err != nil { + responseError(w, err) + return + } + + counter, err := session.KeyCounter() if err != nil { responseError(w, err) return @@ -550,4 +718,20 @@ func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{ }) } // }}} +func newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{ + // Append index.html if needed for further reading of the file + p := requestPath + if p[len(p)-1] == '/' { + p += "index.html" + } + p = config.Application.Directories.Static + p + + base := path.Base(p) + if tmpl, err = template.New(base).ParseFiles(p); err != nil { + return + } + + return +} // }}} + // vim: foldmethod=marker diff --git a/node.go b/node.go index aacdc82..0a688ba 100644 --- a/node.go +++ b/node.go @@ -25,9 +25,9 @@ type Node struct { ContentEncrypted string `db:"content_encrypted" json:"-"` } -func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{ +func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(` + rows, err = db.Queryx(` WITH RECURSIVE nodetree AS ( SELECT *, @@ -62,7 +62,7 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{ ORDER BY path ASC `, - userID, + session.UserID, startNodeID, ) if err != nil { @@ -79,9 +79,6 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{ for rows.Next() { node := Node{} node.Complete = false - node.Crumbs = []Node{} - node.Children = []Node{} - node.Files = []File{} if err = rows.StructScan(&node); err != nil { return } @@ -90,9 +87,9 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{ return }// }}} -func RootNode(userID int) (node Node, err error) {// {{{ +func (session Session) RootNode() (node Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(` + rows, err = db.Queryx(` SELECT id, user_id, @@ -103,7 +100,7 @@ func RootNode(userID int) (node Node, err error) {// {{{ user_id = $1 AND parent_id IS NULL `, - userID, + session.UserID, ) if err != nil { return @@ -111,7 +108,7 @@ func RootNode(userID int) (node Node, err error) {// {{{ defer rows.Close() node.Name = "Start" - node.UserID = userID + node.UserID = session.UserID node.Complete = true node.Children = []Node{} node.Crumbs = []Node{} @@ -132,13 +129,13 @@ func RootNode(userID int) (node Node, err error) {// {{{ return }// }}} -func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{ +func (session Session) Node(nodeID int) (node Node, err error) {// {{{ if nodeID == 0 { - return RootNode(userID) + return session.RootNode() } var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(` + rows, err = db.Queryx(` WITH RECURSIVE recurse AS ( SELECT id, @@ -173,7 +170,7 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{ SELECT * FROM recurse ORDER BY level ASC `, - userID, + session.UserID, nodeID, ) if err != nil { @@ -220,14 +217,14 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{ } } - node.Crumbs, err = NodeCrumbs(node.ID) - node.Files, err = Files(userID, node.ID, 0) + node.Crumbs, err = session.NodeCrumbs(node.ID) + node.Files, err = session.Files(node.ID, 0) return }// }}} -func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{ +func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(` + rows, err = db.Queryx(` WITH RECURSIVE nodes AS ( SELECT id, @@ -263,10 +260,10 @@ func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{ } return }// }}} -func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{{ +func (session Session) CreateNode(parentID int, name string) (node Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(` + rows, err = db.Queryx(` INSERT INTO node(user_id, parent_id, name) VALUES($1, NULLIF($2, 0)::integer, $3) RETURNING @@ -276,7 +273,7 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{ name, content `, - userID, + session.UserID, parentID, name, ) @@ -295,12 +292,12 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{ node.Complete = true } - node.Crumbs, err = NodeCrumbs(node.ID) + node.Crumbs, err = session.NodeCrumbs(node.ID) return }// }}} -func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) {// {{{ +func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (err error) {// {{{ if cryptoKeyID > 0 { - _, err = service.Db.Conn.Exec(` + _, err = db.Exec(` UPDATE node SET content = '', @@ -316,10 +313,10 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) content, cryptoKeyID, nodeID, - userID, + session.UserID, ) } else { - _, err = service.Db.Conn.Exec(` + _, err = db.Exec(` UPDATE node SET content = $1, @@ -335,24 +332,24 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) content, cryptoKeyID, nodeID, - userID, + session.UserID, ) } return }// }}} -func RenameNode(userID, nodeID int, name string) (err error) {// {{{ - _, err = service.Db.Conn.Exec(` +func (session Session) RenameNode(nodeID int, name string) (err error) {// {{{ + _, err = db.Exec(` UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3 `, name, - userID, + session.UserID, nodeID, ) return }// }}} -func DeleteNode(userID, nodeID int) (err error) {// {{{ - _, err = service.Db.Conn.Exec(` +func (session Session) DeleteNode(nodeID int) (err error) {// {{{ + _, err = db.Exec(` WITH RECURSIVE nodetree AS ( SELECT id, parent_id @@ -371,15 +368,15 @@ func DeleteNode(userID, nodeID int) (err error) {// {{{ DELETE FROM node WHERE id IN ( SELECT id FROM nodetree )`, - userID, + session.UserID, nodeID, ) return }// }}} -func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{ +func (session Session) SearchNodes(search string) (nodes []Node, err error) {// {{{ nodes = []Node{} var rows *sqlx.Rows - rows, err = service.Db.Conn.Queryx(` + rows, err = db.Queryx(` SELECT id, user_id, @@ -388,15 +385,14 @@ func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{ updated FROM node WHERE - user_id = $1 AND crypto_key_id IS NULL AND ( - content ~* $2 OR - name ~* $2 + content ~* $1 OR + name ~* $1 ) ORDER BY updated DESC - `, userID, search) + `, search) if err != nil { return } diff --git a/session.go b/session.go new file mode 100644 index 0000000..53d163e --- /dev/null +++ b/session.go @@ -0,0 +1,119 @@ +package main + +import ( + // Standard + "database/sql" + "errors" + "fmt" + "net/http" + "time" +) + +type Session struct { + UUID string + UserID int + Created time.Time +} + +func CreateSession() (session Session, err error) {// {{{ + var rows *sql.Rows + if rows, err = db.Query(` + INSERT INTO public.session(uuid) + VALUES(gen_random_uuid()) + RETURNING uuid, created`, + ); err != nil { + return + } + defer rows.Close() + + if rows.Next() { + rows.Scan(&session.UUID, &session.Created) + } + + return +}// }}} + +func sessionUUID(r *http.Request) (string, error) {// {{{ + headers := r.Header["X-Session-Id"] + if len(headers) > 0 { + return headers[0], nil + } + return "", errors.New("Invalid session") +}// }}} +func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{ + var uuid string + if uuid, err = sessionUUID(r); err != nil { + return + } + + session.UUID = uuid + if found, err = session.Retrieve(); err != nil { + return + } + + if notFoundIsError && !found { + err = errors.New("Invalid session") + return + } + + return +}// }}} + +func (session *Session) Retrieve() (found bool, err error) {// {{{ + var rows *sql.Rows + if rows, err = db.Query(` + SELECT + uuid, user_id, created + FROM public.session + WHERE + uuid = $1 AND + created + $2::interval >= NOW() + `, + session.UUID, + fmt.Sprintf("%d days", config.Session.DaysValid), + ); err != nil { + return + } + defer rows.Close() + + found = false + if rows.Next() { + found = true + rows.Scan(&session.UUID, &session.UserID, &session.Created) + } + + return +}// }}} +func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{ + var rows *sql.Rows + if rows, err = db.Query(` + SELECT id + FROM public.user + WHERE + username=$1 AND + password=password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea) + `, + username, + password, + ); err != nil { + return + } + defer rows.Close() + + if rows.Next() { + rows.Scan(&session.UserID) + authenticated = session.UserID > 0 + } + + if authenticated { + _, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID) + if err != nil { + return + } + } + + return +}// }}} + + +// vim: foldmethod=marker diff --git a/sql/00001.sql b/sql/0001.sql similarity index 100% rename from sql/00001.sql rename to sql/0001.sql diff --git a/sql/00002.sql b/sql/0002.sql similarity index 100% rename from sql/00002.sql rename to sql/0002.sql diff --git a/sql/00003.sql b/sql/0003.sql similarity index 100% rename from sql/00003.sql rename to sql/0003.sql diff --git a/sql/00004.sql b/sql/0004.sql similarity index 100% rename from sql/00004.sql rename to sql/0004.sql diff --git a/sql/00005.sql b/sql/0005.sql similarity index 100% rename from sql/00005.sql rename to sql/0005.sql diff --git a/sql/00006.sql b/sql/0006.sql similarity index 100% rename from sql/00006.sql rename to sql/0006.sql diff --git a/sql/00007.sql b/sql/0007.sql similarity index 100% rename from sql/00007.sql rename to sql/0007.sql diff --git a/sql/00008.sql b/sql/0008.sql similarity index 100% rename from sql/00008.sql rename to sql/0008.sql diff --git a/sql/00009.sql b/sql/0009.sql similarity index 100% rename from sql/00009.sql rename to sql/0009.sql diff --git a/sql/00010.sql b/sql/0010.sql similarity index 100% rename from sql/00010.sql rename to sql/0010.sql diff --git a/sql/00011.sql b/sql/0011.sql similarity index 100% rename from sql/00011.sql rename to sql/0011.sql diff --git a/sql/00012.sql b/sql/0012.sql similarity index 100% rename from sql/00012.sql rename to sql/0012.sql diff --git a/sql/0013.sql b/sql/0013.sql new file mode 100644 index 0000000..d8fb23d --- /dev/null +++ b/sql/0013.sql @@ -0,0 +1,34 @@ +/* Required for the gen_random_bytes function */ +CREATE EXTENSION pgcrypto; + +CREATE FUNCTION password_hash(salt_hex char(32), pass bytea) +RETURNS char(96) +LANGUAGE plpgsql +AS +$$ +BEGIN + RETURN ( + SELECT + salt_hex || + encode( + sha256( + decode(salt_hex, 'hex') || /* salt in binary */ + pass /* password */ + ), + 'hex' + ) + ); +END; +$$; + +/* Password has to be able to accommodate 96 characters instead of previous 64. + * It can't be char(96), because then the password would be padded to 96 characters. */ +ALTER TABLE public."user" ALTER COLUMN "password" TYPE varchar(96) USING "password"::varchar; + +/* Update all users with salted and hashed passwords */ +UPDATE public.user +SET password = password_hash( encode(gen_random_bytes(16),'hex'), password::bytea); + +/* After the password hashing, all passwords are now hex encoded 32 characters salt and 64 characters hash, + * and the varchar type is not longer necessary. */ +ALTER TABLE public."user" ALTER COLUMN "password" TYPE char(96) USING "password"::varchar; diff --git a/static/js/app.mjs b/static/js/app.mjs index 0176d1a..ecb641f 100644 --- a/static/js/app.mjs +++ b/static/js/app.mjs @@ -13,8 +13,8 @@ class App extends Component { this.websocket = null this.websocket_int_ping = null this.websocket_int_reconnect = null - //this.wsConnect() // XXX - //this.wsLoop() // XXX + this.wsConnect() + this.wsLoop() this.session = new Session(this) this.session.initialize() @@ -114,7 +114,7 @@ class App extends Component { } if(this.session.UUID !== '') - headers['X-Session-ID'] = this.session.UUID + headers['X-Session-Id'] = this.session.UUID fetch(url, { method: 'POST', @@ -201,24 +201,9 @@ class Tree extends Component { this.selectedTreeNode = null this.props.app.tree = this - this.retrieve() - }//}}} - render({ app }) {//{{{ - let renderedTreeTrunk = this.treeTrunk.map(node=>{ - this.treeNodeComponents[node.ID] = createRef() - return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />` - }) - return html`