diff --git a/db.go b/db.go index c2db708..1e3d21a 100644 --- a/db.go +++ b/db.go @@ -6,11 +6,7 @@ import ( _ "github.com/lib/pq" // Standard - "errors" "fmt" - "io/fs" - "regexp" - "strconv" ) var ( @@ -33,117 +29,6 @@ func dbInit() (err error) { // {{{ if db, err = sqlx.Connect("postgres", dbConn); err != nil { return } - - if err = dbVerifyInternals(); err != nil { - return - } - - err = dbUpdate() - return -} // }}} -func dbVerifyInternals() (err error) { // {{{ - var rows *sqlx.Rows - if rows, err = db.Queryx( - `SELECT EXISTS ( - SELECT FROM pg_catalog.pg_class c - JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = '_internal' - AND c.relname = 'db' - )`, - ); err != nil { - return - } - defer rows.Close() - var exists bool - rows.Next() - if err = rows.Scan(&exists); err != nil { - return - } - - if !exists { - logger.Info("db", "op", "create_db", "db", "_internal.db") - if _, err = db.Exec(` - CREATE SCHEMA "_internal"; - - CREATE TABLE "_internal".db ( - "key" varchar NOT NULL, - value varchar NULL, - CONSTRAINT db_pk PRIMARY KEY (key) - ); - - INSERT INTO _internal.db("key", "value") - VALUES('schema', '0'); - `, - ); err != nil { - return - } - } - return -} // }}} -func dbUpdate() (err error) { // {{{ - /* Current schema revision is read from database. - * Used to iterate through the embedded SQL updates - * up to the db schema version currently compiled - * program is made for. */ - var rows *sqlx.Rows - var schemaStr string - var schema int - rows, err = db.Queryx(`SELECT value FROM _internal.db WHERE "key"='schema'`) - if err != nil { - return - } - defer rows.Close() - - if !rows.Next() { - return errors.New("Table _interval.db missing schema row") - } - - if err = rows.Scan(&schemaStr); err != nil { - return - } - - // Run updates - schema, err = strconv.Atoi(schemaStr) - if err != nil { - return err - } - sqlSchemaVersion := sqlSchema() - for i := (schema + 1); i <= sqlSchemaVersion; i++ { - logger.Info("db", "op", "upgrade_schema", "schema", i) - sql, _ := embedded.ReadFile( - fmt.Sprintf("sql/%04d.sql", i), - ) - _, err = db.Exec(string(sql)) - if err != nil { - return - } - _, err = db.Exec(`UPDATE _internal.db SET "value"=$1 WHERE "key"='schema'`, i) - if err != nil { - return - } - logger.Info("db", "op", "upgrade_schema", "schema", i, "result", "ok") - } - - return -} // }}} -func sqlSchema() (max int) { // {{{ - var num int - - files, _ := fs.ReadDir(embedded, "sql") - sqlFilename := regexp.MustCompile(`^([0-9]+)\.sql$`) - - for _, file := range files { - fname := sqlFilename.FindStringSubmatch(file.Name()) - if len(fname) != 2 { - continue - } - num, _ = strconv.Atoi(fname[1]) - } - - if num > max { - max = num - } - return } // }}} diff --git a/file.go b/file.go index 0e297ca..0f8b5bc 100644 --- a/file.go +++ b/file.go @@ -5,7 +5,6 @@ import ( "github.com/jmoiron/sqlx" // Standard - "fmt" "time" ) @@ -20,11 +19,11 @@ type File struct { Uploaded time.Time } -func (session Session) AddFile(file *File) (err error) { // {{{ - file.UserID = session.UserID +func AddFile(userID int, file *File) (err error) { // {{{ + file.UserID = userID var rows *sqlx.Rows - rows, err = db.Queryx(` + rows, err = service.Db.Conn.Queryx(` INSERT INTO file(user_id, node_id, filename, size, mime, md5) VALUES($1, $2, $3, $4, $5, $6) RETURNING id @@ -43,12 +42,11 @@ func (session Session) AddFile(file *File) (err error) { // {{{ rows.Next() err = rows.Scan(&file.ID) - fmt.Printf("%#v\n", file) return } // }}} -func (session Session) Files(nodeID, fileID int) (files []File, err error) { // {{{ +func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{ var rows *sqlx.Rows - rows, err = db.Queryx( + rows, err = service.Db.Conn.Queryx( `SELECT * FROM file WHERE @@ -58,7 +56,7 @@ func (session Session) Files(nodeID, fileID int) (files []File, err error) { // WHEN 0 THEN true ELSE id = $3 END`, - session.UserID, + userID, nodeID, fileID, ) diff --git a/key.go b/key.go index 708fe39..aff0f3a 100644 --- a/key.go +++ b/key.go @@ -3,8 +3,6 @@ package main import ( // External "github.com/jmoiron/sqlx" - - // Standard ) type Key struct { @@ -14,9 +12,9 @@ type Key struct { Key string } -func (session Session) Keys() (keys []Key, err error) {// {{{ +func Keys(userID int) (keys []Key, err error) { // {{{ var rows *sqlx.Rows - if rows, err = db.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, session.UserID); err != nil { + if rows, err = service.Db.Conn.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, userID); err != nil { return } defer rows.Close() @@ -31,14 +29,14 @@ func (session Session) Keys() (keys []Key, err error) {// {{{ } return -}// }}} -func (session Session) KeyCreate(description, keyEncoded string) (key Key, err error) {// {{{ +} // }}} +func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) { // {{{ var row *sqlx.Rows - if row, err = db.Queryx( + if row, err = service.Db.Conn.Queryx( `INSERT INTO crypto_key(user_id, description, key) VALUES($1, $2, $3) RETURNING *`, - session.UserID, + userID, description, keyEncoded, ); err != nil { @@ -54,10 +52,10 @@ func (session Session) KeyCreate(description, keyEncoded string) (key Key, err e } return -}// }}} -func (session Session) KeyCounter() (counter int64, err error) {// {{{ +} // }}} +func KeyCounter() (counter int64, err error) { // {{{ var rows *sqlx.Rows - rows, err = db.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`) + rows, err = service.Db.Conn.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`) if err != nil { return } @@ -66,6 +64,6 @@ func (session Session) KeyCounter() (counter int64, err error) {// {{{ rows.Next() err = rows.Scan(&counter) return -}// }}} +} // }}} // vim: foldmethod=marker diff --git a/main.go b/main.go index 6c0e611..7d4e1e3 100644 --- a/main.go +++ b/main.go @@ -1,20 +1,23 @@ package main import ( + // External + "git.gibonuddevalla.se/go/webservice" + + // Internal + "git.gibonuddevalla.se/go/webservice/session" + // Standard "crypto/md5" "embed" "encoding/hex" "flag" "fmt" - "html/template" "io" "log/slog" "net/http" "os" - "path" "path/filepath" - "regexp" "strconv" "strings" ) @@ -25,8 +28,10 @@ var ( flagPort int flagVersion bool flagCreateUser bool + flagCheckLocal bool flagConfig string + service *webservice.Service connectionManager ConnectionManager static http.Handler config Config @@ -34,11 +39,24 @@ var ( VERSION string //go:embed version sql/* - embedded embed.FS + embeddedSQL embed.FS + + //go:embed static + staticFS embed.FS ) +func sqlProvider(dbname string, version int) (sql []byte, found bool) { + var err error + sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version)) + if err != nil { + return + } + found = true + return +} + func init() { // {{{ - version, _ := embedded.ReadFile("version") + version, _ := embeddedSQL.ReadFile("version") VERSION = strings.TrimSpace(string(version)) opt := slog.HandlerOptions{} @@ -48,6 +66,7 @@ func init() { // {{{ flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on") flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists") flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit") + flag.BoolVar(&flagCheckLocal, "checklocal", false, "Check for local static file before embedded") flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file") flag.Parse() } // }}} @@ -58,56 +77,47 @@ func main() { // {{{ fmt.Printf("%s\n", VERSION) os.Exit(0) } - logger.Info("application", "version", VERSION) - config, err = ConfigRead(flagConfig) + service, err = webservice.New(flagConfig, VERSION) if err != nil { - logger.Error("config", "error", err) + logger.Error("application", "error", err) os.Exit(1) } + service.SetDatabase(sqlProvider) + service.SetStaticDirectory("static", true) + service.SetStaticFS(staticFS, "static") - if err = dbInit(); err != nil { - logger.Error("db", "error", err) - os.Exit(1) - } + service.Register("/node/upload", true, true, nodeUpload) + service.Register("/node/tree", true, true, nodeTree) + service.Register("/node/retrieve", true, true, nodeRetrieve) + service.Register("/node/create", true, true, nodeCreate) + service.Register("/node/update", true, true, nodeUpdate) + service.Register("/node/rename", true, true, nodeRename) + service.Register("/node/delete", true, true, nodeDelete) + service.Register("/node/download", true, true, nodeDownload) + service.Register("/node/search", true, true, nodeSearch) + service.Register("/key/retrieve", true, true, keyRetrieve) + service.Register("/key/create", true, true, keyCreate) + service.Register("/key/counter", true, true, keyCounter) + service.Register("/ws", false, false, service.WebsocketHandler) + service.Register("/", false, false, service.StaticHandler) if flagCreateUser { - err = createUser() - if err != nil { - logger.Error("db", "error", err) - os.Exit(1) - } + service.CreateUserPrompt() os.Exit(0) } + err = service.Start() + if err != nil { + logger.Error("webserver", "error", err) + os.Exit(1) + } + connectionManager = NewConnectionManager() go connectionManager.BroadcastLoop() - static = http.FileServer(http.Dir(config.Application.Directories.Static)) http.HandleFunc("/css_updated", cssUpdateHandler) - http.HandleFunc("/session/create", sessionCreate) - http.HandleFunc("/session/retrieve", sessionRetrieve) - http.HandleFunc("/session/authenticate", sessionAuthenticate) - http.HandleFunc("/user/password", userPassword) - http.HandleFunc("/node/tree", nodeTree) - http.HandleFunc("/node/retrieve", nodeRetrieve) - http.HandleFunc("/node/create", nodeCreate) - http.HandleFunc("/node/update", nodeUpdate) - http.HandleFunc("/node/rename", nodeRename) - http.HandleFunc("/node/delete", nodeDelete) - http.HandleFunc("/node/upload", nodeUpload) - http.HandleFunc("/node/download", nodeDownload) - http.HandleFunc("/node/search", nodeSearch) - http.HandleFunc("/key/retrieve", keyRetrieve) - http.HandleFunc("/key/create", keyCreate) - http.HandleFunc("/key/counter", keyCounter) - http.HandleFunc("/ws", websocketHandler) - http.HandleFunc("/", staticHandler) - - listen := fmt.Sprintf("%s:%d", LISTEN_HOST, flagPort) - logger.Info("webserver", "listen", listen, "domains", config.Websocket.Domains) - http.ListenAndServe(listen, nil) } // }}} func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{ @@ -127,108 +137,8 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{ return } } // }}} -func staticHandler(w http.ResponseWriter, r *http.Request) { // {{{ - data := struct { - VERSION string - }{ - VERSION: VERSION, - } - - // URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version. - // To get rid of problems with cached content in browser on a new version release, - // while also not disabling cache altogether. - logger.Debug("webserver", "request", r.URL.Path) - if r.URL.Path == "/favicon.ico" { - static.ServeHTTP(w, r) - return - } - - rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$") - if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil { - r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2]) - static.ServeHTTP(w, r) - return - } - - // Everything else is run through the template system. - // For now to get VERSION into files to fix caching. - logger.Debug("webserver", "template", r.URL.Path) - tmpl, err := newTemplate(r.URL.Path) - if err != nil { - if os.IsNotExist(err) { - w.WriteHeader(404) - } - w.Write([]byte(err.Error())) - return - } - - if err = tmpl.Execute(w, data); err != nil { - w.Write([]byte(err.Error())) - } -} // }}} - -func sessionCreate(w http.ResponseWriter, r *http.Request) { // {{{ - logger.Info("webserver", "request", "/session/create") - session, err := CreateSession() - if err != nil { - responseError(w, err) - return - } - responseData(w, map[string]interface{}{ - "OK": true, - "Session": session, - }) -} // }}} -func sessionRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ - logger.Info("webserver", "request", "/session/retrieve") - var err error - var found bool - var session Session - - if session, found, err = ValidateSession(r, false); err != nil { - responseError(w, err) - return - } - - responseData(w, map[string]interface{}{ - "OK": true, - "Valid": found, - "Session": session, - }) -} // }}} -func sessionAuthenticate(w http.ResponseWriter, r *http.Request) { // {{{ - logger.Info("webserver", "request", "/session/authenticate") - var err error - var session Session - var authenticated bool - - // Validate session - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } - - req := struct { - Username string - Password string - }{} - if err = parseRequest(r, &req); err != nil { - responseError(w, err) - return - } - - if authenticated, err = session.Authenticate(req.Username, req.Password); err != nil { - responseError(w, err) - return - } - - responseData(w, map[string]interface{}{ - "OK": true, - "Authenticated": authenticated, - "Session": session, - }) -} // }}} +/* func userPassword(w http.ResponseWriter, r *http.Request) { // {{{ var err error var ok bool @@ -259,15 +169,10 @@ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{ "CurrentPasswordOK": ok, }) } // }}} +*/ -func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{ +func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } req := struct{ StartNodeID int }{} if err = parseRequest(r, &req); err != nil { @@ -275,7 +180,7 @@ func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{ return } - nodes, err := session.NodeTree(req.StartNodeID) + nodes, err := NodeTree(sess.UserID, req.StartNodeID) if err != nil { responseError(w, err) return @@ -286,15 +191,9 @@ func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{ "Nodes": nodes, }) } // }}} -func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ - logger.Info("webserver", "request", "/node/retrieve") +func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } + logger.Info("webserver", "request", "/node/retrieve") req := struct{ ID int }{} if err = parseRequest(r, &req); err != nil { @@ -302,7 +201,7 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ return } - node, err := session.Node(req.ID) + node, err := RetrieveNode(sess.UserID, req.ID) if err != nil { responseError(w, err) return @@ -313,15 +212,9 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ "Node": node, }) } // }}} -func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{ +func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/node/create") var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } req := struct { Name string @@ -332,7 +225,7 @@ func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{ return } - node, err := session.CreateNode(req.ParentID, req.Name) + node, err := CreateNode(sess.UserID, req.ParentID, req.Name) if err != nil { responseError(w, err) return @@ -343,15 +236,9 @@ func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{ "Node": node, }) } // }}} -func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{ +func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/node/update") var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } req := struct { NodeID int @@ -363,7 +250,7 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{ return } - err = session.UpdateNode(req.NodeID, req.Content, req.CryptoKeyID) + err = UpdateNode(sess.UserID, req.NodeID, req.Content, req.CryptoKeyID) if err != nil { responseError(w, err) return @@ -373,16 +260,9 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{ "OK": true, }) } // }}} -func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{ - logger.Info("webserver", "request", "/node/rename") - +func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } + logger.Info("webserver", "request", "/node/rename") req := struct { NodeID int @@ -393,7 +273,7 @@ func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{ return } - err = session.RenameNode(req.NodeID, req.Name) + err = RenameNode(sess.UserID, req.NodeID, req.Name) if err != nil { responseError(w, err) return @@ -403,16 +283,9 @@ func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{ "OK": true, }) } // }}} -func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{ - logger.Info("webserver", "request", "/node/delete") - +func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } + logger.Info("webserver", "request", "/node/delete") req := struct { NodeID int @@ -422,7 +295,7 @@ func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{ return } - err = session.DeleteNode(req.NodeID) + err = DeleteNode(sess.UserID, req.NodeID) if err != nil { responseError(w, err) return @@ -432,15 +305,9 @@ func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{ "OK": true, }) } // }}} -func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{ +func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/node/upload") var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } // Parse our multipart form, 10 << 20 specifies a maximum // upload of 10 MB files. @@ -479,7 +346,7 @@ func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{ MIME: handler.Header.Get("Content-Type"), MD5: md5sum, } - if err = session.AddFile(&nodeFile); err != nil { + if err = AddFile(sess.UserID, &nodeFile); err != nil { responseError(w, err) return } @@ -516,17 +383,11 @@ func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{ "File": nodeFile, }) } // }}} -func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{ +func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/node/download") var err error - var session Session var files []File - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } - req := struct { NodeID int FileID int @@ -536,7 +397,7 @@ func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{ return } - files, err = session.Files(req.NodeID, req.FileID) + files, err = Files(sess.UserID, req.NodeID, req.FileID) if err != nil { responseError(w, err) return @@ -583,17 +444,11 @@ func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{ } } // }}} -func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{ +func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/node/files") var err error - var session Session var files []File - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } - req := struct { NodeID int }{} @@ -602,7 +457,7 @@ func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{ return } - files, err = session.Files(req.NodeID, 0) + files, err = Files(sess.UserID, req.NodeID, 0) if err != nil { responseError(w, err) return @@ -613,17 +468,11 @@ func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{ "Files": files, }) } // }}} -func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{ +func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/node/search") var err error - var session Session var nodes []Node - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } - req := struct { Search string }{} @@ -632,7 +481,7 @@ func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{ return } - nodes, err = session.SearchNodes(req.Search) + nodes, err = SearchNodes(sess.UserID, req.Search) if err != nil { responseError(w, err) return @@ -644,17 +493,11 @@ func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{ }) } // }}} -func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ +func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/key/retrieve") var err error - var session Session - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } - - keys, err := session.Keys() + keys, err := Keys(sess.UserID) if err != nil { responseError(w, err) return @@ -665,15 +508,9 @@ func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ "Keys": keys, }) } // }}} -func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{ +func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/key/create") var err error - var session Session - - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } req := struct { Description string @@ -684,7 +521,7 @@ func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{ return } - key, err := session.KeyCreate(req.Description, req.Key) + key, err := KeyCreate(sess.UserID, req.Description, req.Key) if err != nil { responseError(w, err) return @@ -695,17 +532,11 @@ func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{ "Key": key, }) } // }}} -func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{ +func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ logger.Info("webserver", "request", "/key/counter") var err error - var session Session - if session, _, err = ValidateSession(r, true); err != nil { - responseError(w, err) - return - } - - counter, err := session.KeyCounter() + counter, err := KeyCounter() if err != nil { responseError(w, err) return @@ -718,20 +549,4 @@ func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{ }) } // }}} -func newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{ - // Append index.html if needed for further reading of the file - p := requestPath - if p[len(p)-1] == '/' { - p += "index.html" - } - p = config.Application.Directories.Static + p - - base := path.Base(p) - if tmpl, err = template.New(base).ParseFiles(p); err != nil { - return - } - - return -} // }}} - // vim: foldmethod=marker diff --git a/node.go b/node.go index 0a688ba..2ebf876 100644 --- a/node.go +++ b/node.go @@ -25,9 +25,9 @@ type Node struct { ContentEncrypted string `db:"content_encrypted" json:"-"` } -func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {{{ +func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = db.Queryx(` + rows, err = service.Db.Conn.Queryx(` WITH RECURSIVE nodetree AS ( SELECT *, @@ -62,7 +62,7 @@ func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// { ORDER BY path ASC `, - session.UserID, + userID, startNodeID, ) if err != nil { @@ -87,9 +87,9 @@ func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// { return }// }}} -func (session Session) RootNode() (node Node, err error) {// {{{ +func RootNode(userID int) (node Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = db.Queryx(` + rows, err = service.Db.Conn.Queryx(` SELECT id, user_id, @@ -100,7 +100,7 @@ func (session Session) RootNode() (node Node, err error) {// {{{ user_id = $1 AND parent_id IS NULL `, - session.UserID, + userID, ) if err != nil { return @@ -108,7 +108,7 @@ func (session Session) RootNode() (node Node, err error) {// {{{ defer rows.Close() node.Name = "Start" - node.UserID = session.UserID + node.UserID = userID node.Complete = true node.Children = []Node{} node.Crumbs = []Node{} @@ -129,13 +129,13 @@ func (session Session) RootNode() (node Node, err error) {// {{{ return }// }}} -func (session Session) Node(nodeID int) (node Node, err error) {// {{{ +func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{ if nodeID == 0 { - return session.RootNode() + return RootNode(userID) } var rows *sqlx.Rows - rows, err = db.Queryx(` + rows, err = service.Db.Conn.Queryx(` WITH RECURSIVE recurse AS ( SELECT id, @@ -170,7 +170,7 @@ func (session Session) Node(nodeID int) (node Node, err error) {// {{{ SELECT * FROM recurse ORDER BY level ASC `, - session.UserID, + userID, nodeID, ) if err != nil { @@ -217,14 +217,14 @@ func (session Session) Node(nodeID int) (node Node, err error) {// {{{ } } - node.Crumbs, err = session.NodeCrumbs(node.ID) - node.Files, err = session.Files(node.ID, 0) + node.Crumbs, err = NodeCrumbs(node.ID) + node.Files, err = Files(userID, node.ID, 0) return }// }}} -func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{ +func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = db.Queryx(` + rows, err = service.Db.Conn.Queryx(` WITH RECURSIVE nodes AS ( SELECT id, @@ -260,10 +260,10 @@ func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{ } return }// }}} -func (session Session) CreateNode(parentID int, name string) (node Node, err error) {// {{{ +func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{{ var rows *sqlx.Rows - rows, err = db.Queryx(` + rows, err = service.Db.Conn.Queryx(` INSERT INTO node(user_id, parent_id, name) VALUES($1, NULLIF($2, 0)::integer, $3) RETURNING @@ -273,7 +273,7 @@ func (session Session) CreateNode(parentID int, name string) (node Node, err err name, content `, - session.UserID, + userID, parentID, name, ) @@ -292,12 +292,12 @@ func (session Session) CreateNode(parentID int, name string) (node Node, err err node.Complete = true } - node.Crumbs, err = session.NodeCrumbs(node.ID) + node.Crumbs, err = NodeCrumbs(node.ID) return }// }}} -func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (err error) {// {{{ +func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) {// {{{ if cryptoKeyID > 0 { - _, err = db.Exec(` + _, err = service.Db.Conn.Exec(` UPDATE node SET content = '', @@ -313,10 +313,10 @@ func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) ( content, cryptoKeyID, nodeID, - session.UserID, + userID, ) } else { - _, err = db.Exec(` + _, err = service.Db.Conn.Exec(` UPDATE node SET content = $1, @@ -332,24 +332,24 @@ func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) ( content, cryptoKeyID, nodeID, - session.UserID, + userID, ) } return }// }}} -func (session Session) RenameNode(nodeID int, name string) (err error) {// {{{ - _, err = db.Exec(` +func RenameNode(userID, nodeID int, name string) (err error) {// {{{ + _, err = service.Db.Conn.Exec(` UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3 `, name, - session.UserID, + userID, nodeID, ) return }// }}} -func (session Session) DeleteNode(nodeID int) (err error) {// {{{ - _, err = db.Exec(` +func DeleteNode(userID, nodeID int) (err error) {// {{{ + _, err = service.Db.Conn.Exec(` WITH RECURSIVE nodetree AS ( SELECT id, parent_id @@ -368,15 +368,15 @@ func (session Session) DeleteNode(nodeID int) (err error) {// {{{ DELETE FROM node WHERE id IN ( SELECT id FROM nodetree )`, - session.UserID, + userID, nodeID, ) return }// }}} -func (session Session) SearchNodes(search string) (nodes []Node, err error) {// {{{ +func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{ nodes = []Node{} var rows *sqlx.Rows - rows, err = db.Queryx(` + rows, err = service.Db.Conn.Queryx(` SELECT id, user_id, @@ -385,14 +385,15 @@ func (session Session) SearchNodes(search string) (nodes []Node, err error) {// updated FROM node WHERE + user_id = $1 crypto_key_id IS NULL AND ( - content ~* $1 OR - name ~* $1 + content ~* $2 OR + name ~* $2 ) ORDER BY updated DESC - `, search) + `, userID, search) if err != nil { return } diff --git a/session.go b/session.go deleted file mode 100644 index 53d163e..0000000 --- a/session.go +++ /dev/null @@ -1,119 +0,0 @@ -package main - -import ( - // Standard - "database/sql" - "errors" - "fmt" - "net/http" - "time" -) - -type Session struct { - UUID string - UserID int - Created time.Time -} - -func CreateSession() (session Session, err error) {// {{{ - var rows *sql.Rows - if rows, err = db.Query(` - INSERT INTO public.session(uuid) - VALUES(gen_random_uuid()) - RETURNING uuid, created`, - ); err != nil { - return - } - defer rows.Close() - - if rows.Next() { - rows.Scan(&session.UUID, &session.Created) - } - - return -}// }}} - -func sessionUUID(r *http.Request) (string, error) {// {{{ - headers := r.Header["X-Session-Id"] - if len(headers) > 0 { - return headers[0], nil - } - return "", errors.New("Invalid session") -}// }}} -func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{ - var uuid string - if uuid, err = sessionUUID(r); err != nil { - return - } - - session.UUID = uuid - if found, err = session.Retrieve(); err != nil { - return - } - - if notFoundIsError && !found { - err = errors.New("Invalid session") - return - } - - return -}// }}} - -func (session *Session) Retrieve() (found bool, err error) {// {{{ - var rows *sql.Rows - if rows, err = db.Query(` - SELECT - uuid, user_id, created - FROM public.session - WHERE - uuid = $1 AND - created + $2::interval >= NOW() - `, - session.UUID, - fmt.Sprintf("%d days", config.Session.DaysValid), - ); err != nil { - return - } - defer rows.Close() - - found = false - if rows.Next() { - found = true - rows.Scan(&session.UUID, &session.UserID, &session.Created) - } - - return -}// }}} -func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{ - var rows *sql.Rows - if rows, err = db.Query(` - SELECT id - FROM public.user - WHERE - username=$1 AND - password=password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea) - `, - username, - password, - ); err != nil { - return - } - defer rows.Close() - - if rows.Next() { - rows.Scan(&session.UserID) - authenticated = session.UserID > 0 - } - - if authenticated { - _, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID) - if err != nil { - return - } - } - - return -}// }}} - - -// vim: foldmethod=marker diff --git a/sql/0001.sql b/sql/00001.sql similarity index 100% rename from sql/0001.sql rename to sql/00001.sql diff --git a/sql/0002.sql b/sql/00002.sql similarity index 100% rename from sql/0002.sql rename to sql/00002.sql diff --git a/sql/0003.sql b/sql/00003.sql similarity index 100% rename from sql/0003.sql rename to sql/00003.sql diff --git a/sql/0004.sql b/sql/00004.sql similarity index 100% rename from sql/0004.sql rename to sql/00004.sql diff --git a/sql/0005.sql b/sql/00005.sql similarity index 100% rename from sql/0005.sql rename to sql/00005.sql diff --git a/sql/0006.sql b/sql/00006.sql similarity index 100% rename from sql/0006.sql rename to sql/00006.sql diff --git a/sql/0007.sql b/sql/00007.sql similarity index 100% rename from sql/0007.sql rename to sql/00007.sql diff --git a/sql/0008.sql b/sql/00008.sql similarity index 100% rename from sql/0008.sql rename to sql/00008.sql diff --git a/sql/0009.sql b/sql/00009.sql similarity index 100% rename from sql/0009.sql rename to sql/00009.sql diff --git a/sql/0010.sql b/sql/00010.sql similarity index 100% rename from sql/0010.sql rename to sql/00010.sql diff --git a/sql/0011.sql b/sql/00011.sql similarity index 100% rename from sql/0011.sql rename to sql/00011.sql diff --git a/sql/0012.sql b/sql/00012.sql similarity index 100% rename from sql/0012.sql rename to sql/00012.sql diff --git a/sql/0013.sql b/sql/0013.sql deleted file mode 100644 index d8fb23d..0000000 --- a/sql/0013.sql +++ /dev/null @@ -1,34 +0,0 @@ -/* Required for the gen_random_bytes function */ -CREATE EXTENSION pgcrypto; - -CREATE FUNCTION password_hash(salt_hex char(32), pass bytea) -RETURNS char(96) -LANGUAGE plpgsql -AS -$$ -BEGIN - RETURN ( - SELECT - salt_hex || - encode( - sha256( - decode(salt_hex, 'hex') || /* salt in binary */ - pass /* password */ - ), - 'hex' - ) - ); -END; -$$; - -/* Password has to be able to accommodate 96 characters instead of previous 64. - * It can't be char(96), because then the password would be padded to 96 characters. */ -ALTER TABLE public."user" ALTER COLUMN "password" TYPE varchar(96) USING "password"::varchar; - -/* Update all users with salted and hashed passwords */ -UPDATE public.user -SET password = password_hash( encode(gen_random_bytes(16),'hex'), password::bytea); - -/* After the password hashing, all passwords are now hex encoded 32 characters salt and 64 characters hash, - * and the varchar type is not longer necessary. */ -ALTER TABLE public."user" ALTER COLUMN "password" TYPE char(96) USING "password"::varchar; diff --git a/static/js/app.mjs b/static/js/app.mjs index ecb641f..740f475 100644 --- a/static/js/app.mjs +++ b/static/js/app.mjs @@ -13,8 +13,8 @@ class App extends Component { this.websocket = null this.websocket_int_ping = null this.websocket_int_reconnect = null - this.wsConnect() - this.wsLoop() + //this.wsConnect() // XXX + //this.wsLoop() // XXX this.session = new Session(this) this.session.initialize() @@ -25,7 +25,7 @@ class App extends Component { this.startNode = null - this.setStartNode() + //this.setStartNode() }//}}} render() {//{{{ let app_el = document.getElementById('app') @@ -114,7 +114,7 @@ class App extends Component { } if(this.session.UUID !== '') - headers['X-Session-Id'] = this.session.UUID + headers['X-Session-ID'] = this.session.UUID fetch(url, { method: 'POST', @@ -201,9 +201,24 @@ class Tree extends Component { this.selectedTreeNode = null this.props.app.tree = this + this.retrieve() + }//}}} + render({ app }) {//{{{ + let renderedTreeTrunk = this.treeTrunk.map(node=>{ + this.treeNodeComponents[node.ID] = createRef() + return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />` + }) + return html`
${renderedTreeTrunk}
` + }//}}} + retrieve(callback = null) {//{{{ this.props.app.request('/node/tree', { StartNodeID: 0 }) .then(res=>{ + this.treeNodes = {} + this.treeNodeComponents = {} + this.treeTrunk = [] + this.selectedTreeNode = null + // A tree of nodes is built. This requires the list of nodes // returned from the server to be sorted in such a way that // a parent node always appears before a child node. @@ -235,17 +250,12 @@ class Tree extends Component { this.crumbsUpdateNodes() this.forceUpdate() + if(callback) + callback() + }) .catch(this.responseError) }//}}} - render({ app }) {//{{{ - let renderedTreeTrunk = this.treeTrunk.map(node=>{ - this.treeNodeComponents[node.ID] = createRef() - return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />` - }) - return html`
${renderedTreeTrunk}
` - }//}}} - setSelected(node) {//{{{ if(this.selectedTreeNode) this.selectedTreeNode.selected.value = false diff --git a/static/js/node.mjs b/static/js/node.mjs index 6f1b4bf..15d3dfb 100644 --- a/static/js/node.mjs +++ b/static/js/node.mjs @@ -215,7 +215,14 @@ export class NodeUI extends Component { let name = prompt("Name") if(!name) return - this.node.value.create(name, nodeID=>this.goToNode(nodeID)) + this.node.value.create(name, nodeID=>{ + console.log('before', this.props.app.startNode) + this.props.app.startNode = new Node(this.props.app, nodeID) + console.log('after', this.props.app.startNode) + this.props.app.tree.retrieve(()=>{ + this.goToNode(nodeID) + }) + }) }//}}} saveNode() {//{{{ let nodeContent = this.nodeContent.current diff --git a/static/js/session.mjs b/static/js/session.mjs index 9eedf70..a5dcc06 100644 --- a/static/js/session.mjs +++ b/static/js/session.mjs @@ -10,11 +10,10 @@ export class Session { // Retrieving the stored session UUID, if any. // If one found, validate with server. - // If the browser doesn't know anything about a session, // a call to /session/create is necessary to retrieve a session UUID. - let uuid= window.localStorage.getItem("session.UUID") - if(uuid === null) { + let uuid = window.localStorage.getItem("session.UUID") + if (uuid === null) { this.create() return } @@ -25,47 +24,47 @@ export class Session { // A call to /session/retrieve with a session UUID validates that the // session is still valid and returns all session information. this.UUID = uuid - this.app.request('/session/retrieve', {}) - .then(res=>{ - if(res.Valid) { - // Session exists on server. - // Not necessarily authenticated. - this.UserID = res.Session.UserID // could be 0 - this.initialized = true - this.app.forceUpdate() - } else { - // Session has probably expired. A new is required. - this.create() - } - }) - .catch(this.app.responseError) + this.app.request('/_session/retrieve', {}) + .then(res => { + if (res.Error === undefined) { + // Session exists on server. + // Not necessarily authenticated. + this.UserID = res.UserID // could be 0 + this.initialized = true + this.app.forceUpdate() + } else { + // Session has probably expired. A new is required. + this.create() + } + }) + .catch(this.app.responseError) }//}}} create() {//{{{ - this.app.request('/session/create', {}) - .then(res=>{ - this.UUID = res.Session.UUID - window.localStorage.setItem('session.UUID', this.UUID) - this.initialized = true - this.app.forceUpdate() - }) - .catch(this.responseError) + this.app.request('/_session/new', {}) + .then(res => { + this.UUID = res.Session.UUID + window.localStorage.setItem('session.UUID', this.Session.UUID) + this.initialized = true + this.app.forceUpdate() + }) + .catch(this.responseError) }//}}} authenticate(username, password) {//{{{ this.app.login.current.authentication_failed.value = false - this.app.request('/session/authenticate', { + this.app.request('/_session/authenticate', { username, password, }) - .then(res=>{ - if(res.Authenticated) { - this.UserID = res.Session.UserID - this.app.forceUpdate() - } else { - this.app.login.current.authentication_failed.value = true - } - }) - .catch(this.app.responseError) + .then(res => { + if (res.Authenticated) { + this.UserID = res.Session.UserID + this.app.forceUpdate() + } else { + this.app.login.current.authentication_failed.value = true + } + }) + .catch(this.app.responseError) }//}}} authenticated() {//{{{ return this.UserID != 0 diff --git a/user.go b/user.go index c195b90..cfa5232 100644 --- a/user.go +++ b/user.go @@ -1,11 +1,6 @@ package main -import ( - // Standard - "database/sql" - "fmt" -) - +/* func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err error) { var result sql.Result var rowsAffected int64 @@ -14,10 +9,10 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er UPDATE public.user SET password = password_hash( - /* salt in hex */ + / salt in hex / ENCODE(gen_random_bytes(16), 'hex'), - /* password */ + / password / $1::bytea ) WHERE @@ -36,38 +31,4 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er return rowsAffected > 0, nil } - -func createUser() (err error) { - var username, password string - fmt.Printf("Username: ") - fmt.Scanln(&username) - fmt.Printf("Password: ") - fmt.Scanln(&password) - - err = CreateDbUser(username, password) - return -} - -func CreateDbUser(username string, password string) (err error) { - var result sql.Result - - result, err = db.Exec(` - INSERT INTO public.user(username, password) - VALUES( - $1, - password_hash( - /* salt in hex */ - ENCODE(gen_random_bytes(16), 'hex'), - - /* password */ - $2::bytea - ) - ) - `, - username, - password, - ) - - _, err = result.RowsAffected() - return -} +*/