wip: rewrite to webservice library

This commit is contained in:
Magnus Åhall 2024-01-05 20:00:02 +01:00
parent abbd320b93
commit 52fba2289e
23 changed files with 201 additions and 680 deletions

115
db.go
View File

@ -6,11 +6,7 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
// Standard // Standard
"errors"
"fmt" "fmt"
"io/fs"
"regexp"
"strconv"
) )
var ( var (
@ -33,117 +29,6 @@ func dbInit() (err error) { // {{{
if db, err = sqlx.Connect("postgres", dbConn); err != nil { if db, err = sqlx.Connect("postgres", dbConn); err != nil {
return return
} }
if err = dbVerifyInternals(); err != nil {
return
}
err = dbUpdate()
return
} // }}}
func dbVerifyInternals() (err error) { // {{{
var rows *sqlx.Rows
if rows, err = db.Queryx(
`SELECT EXISTS (
SELECT FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = '_internal'
AND c.relname = 'db'
)`,
); err != nil {
return
}
defer rows.Close()
var exists bool
rows.Next()
if err = rows.Scan(&exists); err != nil {
return
}
if !exists {
logger.Info("db", "op", "create_db", "db", "_internal.db")
if _, err = db.Exec(`
CREATE SCHEMA "_internal";
CREATE TABLE "_internal".db (
"key" varchar NOT NULL,
value varchar NULL,
CONSTRAINT db_pk PRIMARY KEY (key)
);
INSERT INTO _internal.db("key", "value")
VALUES('schema', '0');
`,
); err != nil {
return
}
}
return
} // }}}
func dbUpdate() (err error) { // {{{
/* Current schema revision is read from database.
* Used to iterate through the embedded SQL updates
* up to the db schema version currently compiled
* program is made for. */
var rows *sqlx.Rows
var schemaStr string
var schema int
rows, err = db.Queryx(`SELECT value FROM _internal.db WHERE "key"='schema'`)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return errors.New("Table _interval.db missing schema row")
}
if err = rows.Scan(&schemaStr); err != nil {
return
}
// Run updates
schema, err = strconv.Atoi(schemaStr)
if err != nil {
return err
}
sqlSchemaVersion := sqlSchema()
for i := (schema + 1); i <= sqlSchemaVersion; i++ {
logger.Info("db", "op", "upgrade_schema", "schema", i)
sql, _ := embedded.ReadFile(
fmt.Sprintf("sql/%04d.sql", i),
)
_, err = db.Exec(string(sql))
if err != nil {
return
}
_, err = db.Exec(`UPDATE _internal.db SET "value"=$1 WHERE "key"='schema'`, i)
if err != nil {
return
}
logger.Info("db", "op", "upgrade_schema", "schema", i, "result", "ok")
}
return
} // }}}
func sqlSchema() (max int) { // {{{
var num int
files, _ := fs.ReadDir(embedded, "sql")
sqlFilename := regexp.MustCompile(`^([0-9]+)\.sql$`)
for _, file := range files {
fname := sqlFilename.FindStringSubmatch(file.Name())
if len(fname) != 2 {
continue
}
num, _ = strconv.Atoi(fname[1])
}
if num > max {
max = num
}
return return
} // }}} } // }}}

14
file.go
View File

@ -5,7 +5,6 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
// Standard // Standard
"fmt"
"time" "time"
) )
@ -20,11 +19,11 @@ type File struct {
Uploaded time.Time Uploaded time.Time
} }
func (session Session) AddFile(file *File) (err error) { // {{{ func AddFile(userID int, file *File) (err error) { // {{{
file.UserID = session.UserID file.UserID = userID
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(` rows, err = service.Db.Conn.Queryx(`
INSERT INTO file(user_id, node_id, filename, size, mime, md5) INSERT INTO file(user_id, node_id, filename, size, mime, md5)
VALUES($1, $2, $3, $4, $5, $6) VALUES($1, $2, $3, $4, $5, $6)
RETURNING id RETURNING id
@ -43,12 +42,11 @@ func (session Session) AddFile(file *File) (err error) { // {{{
rows.Next() rows.Next()
err = rows.Scan(&file.ID) err = rows.Scan(&file.ID)
fmt.Printf("%#v\n", file)
return return
} // }}} } // }}}
func (session Session) Files(nodeID, fileID int) (files []File, err error) { // {{{ func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx( rows, err = service.Db.Conn.Queryx(
`SELECT * `SELECT *
FROM file FROM file
WHERE WHERE
@ -58,7 +56,7 @@ func (session Session) Files(nodeID, fileID int) (files []File, err error) { //
WHEN 0 THEN true WHEN 0 THEN true
ELSE id = $3 ELSE id = $3
END`, END`,
session.UserID, userID,
nodeID, nodeID,
fileID, fileID,
) )

22
key.go
View File

@ -3,8 +3,6 @@ package main
import ( import (
// External // External
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
// Standard
) )
type Key struct { type Key struct {
@ -14,9 +12,9 @@ type Key struct {
Key string Key string
} }
func (session Session) Keys() (keys []Key, err error) {// {{{ func Keys(userID int) (keys []Key, err error) { // {{{
var rows *sqlx.Rows var rows *sqlx.Rows
if rows, err = db.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, session.UserID); err != nil { if rows, err = service.Db.Conn.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, userID); err != nil {
return return
} }
defer rows.Close() defer rows.Close()
@ -31,14 +29,14 @@ func (session Session) Keys() (keys []Key, err error) {// {{{
} }
return return
}// }}} } // }}}
func (session Session) KeyCreate(description, keyEncoded string) (key Key, err error) {// {{{ func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) { // {{{
var row *sqlx.Rows var row *sqlx.Rows
if row, err = db.Queryx( if row, err = service.Db.Conn.Queryx(
`INSERT INTO crypto_key(user_id, description, key) `INSERT INTO crypto_key(user_id, description, key)
VALUES($1, $2, $3) VALUES($1, $2, $3)
RETURNING *`, RETURNING *`,
session.UserID, userID,
description, description,
keyEncoded, keyEncoded,
); err != nil { ); err != nil {
@ -54,10 +52,10 @@ func (session Session) KeyCreate(description, keyEncoded string) (key Key, err e
} }
return return
}// }}} } // }}}
func (session Session) KeyCounter() (counter int64, err error) {// {{{ func KeyCounter() (counter int64, err error) { // {{{
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`) rows, err = service.Db.Conn.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
if err != nil { if err != nil {
return return
} }
@ -66,6 +64,6 @@ func (session Session) KeyCounter() (counter int64, err error) {// {{{
rows.Next() rows.Next()
err = rows.Scan(&counter) err = rows.Scan(&counter)
return return
}// }}} } // }}}
// vim: foldmethod=marker // vim: foldmethod=marker

347
main.go
View File

@ -1,20 +1,23 @@
package main package main
import ( import (
// External
"git.gibonuddevalla.se/go/webservice"
// Internal
"git.gibonuddevalla.se/go/webservice/session"
// Standard // Standard
"crypto/md5" "crypto/md5"
"embed" "embed"
"encoding/hex" "encoding/hex"
"flag" "flag"
"fmt" "fmt"
"html/template"
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
"path"
"path/filepath" "path/filepath"
"regexp"
"strconv" "strconv"
"strings" "strings"
) )
@ -25,8 +28,10 @@ var (
flagPort int flagPort int
flagVersion bool flagVersion bool
flagCreateUser bool flagCreateUser bool
flagCheckLocal bool
flagConfig string flagConfig string
service *webservice.Service
connectionManager ConnectionManager connectionManager ConnectionManager
static http.Handler static http.Handler
config Config config Config
@ -34,11 +39,24 @@ var (
VERSION string VERSION string
//go:embed version sql/* //go:embed version sql/*
embedded embed.FS embeddedSQL embed.FS
//go:embed static
staticFS embed.FS
) )
func sqlProvider(dbname string, version int) (sql []byte, found bool) {
var err error
sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version))
if err != nil {
return
}
found = true
return
}
func init() { // {{{ func init() { // {{{
version, _ := embedded.ReadFile("version") version, _ := embeddedSQL.ReadFile("version")
VERSION = strings.TrimSpace(string(version)) VERSION = strings.TrimSpace(string(version))
opt := slog.HandlerOptions{} opt := slog.HandlerOptions{}
@ -48,6 +66,7 @@ func init() { // {{{
flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on") flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on")
flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists") flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists")
flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit") flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit")
flag.BoolVar(&flagCheckLocal, "checklocal", false, "Check for local static file before embedded")
flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file") flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file")
flag.Parse() flag.Parse()
} // }}} } // }}}
@ -58,56 +77,47 @@ func main() { // {{{
fmt.Printf("%s\n", VERSION) fmt.Printf("%s\n", VERSION)
os.Exit(0) os.Exit(0)
} }
logger.Info("application", "version", VERSION) logger.Info("application", "version", VERSION)
config, err = ConfigRead(flagConfig) service, err = webservice.New(flagConfig, VERSION)
if err != nil { if err != nil {
logger.Error("config", "error", err) logger.Error("application", "error", err)
os.Exit(1) os.Exit(1)
} }
service.SetDatabase(sqlProvider)
service.SetStaticDirectory("static", true)
service.SetStaticFS(staticFS, "static")
if err = dbInit(); err != nil { service.Register("/node/upload", true, true, nodeUpload)
logger.Error("db", "error", err) service.Register("/node/tree", true, true, nodeTree)
os.Exit(1) service.Register("/node/retrieve", true, true, nodeRetrieve)
} service.Register("/node/create", true, true, nodeCreate)
service.Register("/node/update", true, true, nodeUpdate)
service.Register("/node/rename", true, true, nodeRename)
service.Register("/node/delete", true, true, nodeDelete)
service.Register("/node/download", true, true, nodeDownload)
service.Register("/node/search", true, true, nodeSearch)
service.Register("/key/retrieve", true, true, keyRetrieve)
service.Register("/key/create", true, true, keyCreate)
service.Register("/key/counter", true, true, keyCounter)
service.Register("/ws", false, false, service.WebsocketHandler)
service.Register("/", false, false, service.StaticHandler)
if flagCreateUser { if flagCreateUser {
err = createUser() service.CreateUserPrompt()
if err != nil {
logger.Error("db", "error", err)
os.Exit(1)
}
os.Exit(0) os.Exit(0)
} }
err = service.Start()
if err != nil {
logger.Error("webserver", "error", err)
os.Exit(1)
}
connectionManager = NewConnectionManager() connectionManager = NewConnectionManager()
go connectionManager.BroadcastLoop() go connectionManager.BroadcastLoop()
static = http.FileServer(http.Dir(config.Application.Directories.Static))
http.HandleFunc("/css_updated", cssUpdateHandler) http.HandleFunc("/css_updated", cssUpdateHandler)
http.HandleFunc("/session/create", sessionCreate)
http.HandleFunc("/session/retrieve", sessionRetrieve)
http.HandleFunc("/session/authenticate", sessionAuthenticate)
http.HandleFunc("/user/password", userPassword)
http.HandleFunc("/node/tree", nodeTree)
http.HandleFunc("/node/retrieve", nodeRetrieve)
http.HandleFunc("/node/create", nodeCreate)
http.HandleFunc("/node/update", nodeUpdate)
http.HandleFunc("/node/rename", nodeRename)
http.HandleFunc("/node/delete", nodeDelete)
http.HandleFunc("/node/upload", nodeUpload)
http.HandleFunc("/node/download", nodeDownload)
http.HandleFunc("/node/search", nodeSearch)
http.HandleFunc("/key/retrieve", keyRetrieve)
http.HandleFunc("/key/create", keyCreate)
http.HandleFunc("/key/counter", keyCounter)
http.HandleFunc("/ws", websocketHandler)
http.HandleFunc("/", staticHandler)
listen := fmt.Sprintf("%s:%d", LISTEN_HOST, flagPort)
logger.Info("webserver", "listen", listen, "domains", config.Websocket.Domains)
http.ListenAndServe(listen, nil)
} // }}} } // }}}
func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{ func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{
@ -127,108 +137,8 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
} // }}} } // }}}
func staticHandler(w http.ResponseWriter, r *http.Request) { // {{{
data := struct {
VERSION string
}{
VERSION: VERSION,
}
// URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version.
// To get rid of problems with cached content in browser on a new version release,
// while also not disabling cache altogether.
logger.Debug("webserver", "request", r.URL.Path)
if r.URL.Path == "/favicon.ico" {
static.ServeHTTP(w, r)
return
}
rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
static.ServeHTTP(w, r)
return
}
// Everything else is run through the template system.
// For now to get VERSION into files to fix caching.
logger.Debug("webserver", "template", r.URL.Path)
tmpl, err := newTemplate(r.URL.Path)
if err != nil {
if os.IsNotExist(err) {
w.WriteHeader(404)
}
w.Write([]byte(err.Error()))
return
}
if err = tmpl.Execute(w, data); err != nil {
w.Write([]byte(err.Error()))
}
} // }}}
func sessionCreate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/create")
session, err := CreateSession()
if err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Session": session,
})
} // }}}
func sessionRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/retrieve")
var err error
var found bool
var session Session
if session, found, err = ValidateSession(r, false); err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Valid": found,
"Session": session,
})
} // }}}
func sessionAuthenticate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/authenticate")
var err error
var session Session
var authenticated bool
// Validate session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Username string
Password string
}{}
if err = parseRequest(r, &req); err != nil {
responseError(w, err)
return
}
if authenticated, err = session.Authenticate(req.Username, req.Password); err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Authenticated": authenticated,
"Session": session,
})
} // }}}
/*
func userPassword(w http.ResponseWriter, r *http.Request) { // {{{ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
var err error var err error
var ok bool var ok bool
@ -259,15 +169,10 @@ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
"CurrentPasswordOK": ok, "CurrentPasswordOK": ok,
}) })
} // }}} } // }}}
*/
func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
var err error var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct{ StartNodeID int }{} req := struct{ StartNodeID int }{}
if err = parseRequest(r, &req); err != nil { if err = parseRequest(r, &req); err != nil {
@ -275,7 +180,7 @@ func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
nodes, err := session.NodeTree(req.StartNodeID) nodes, err := NodeTree(sess.UserID, req.StartNodeID)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -286,15 +191,9 @@ func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
"Nodes": nodes, "Nodes": nodes,
}) })
} // }}} } // }}}
func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/retrieve")
var err error var err error
var session Session logger.Info("webserver", "request", "/node/retrieve")
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct{ ID int }{} req := struct{ ID int }{}
if err = parseRequest(r, &req); err != nil { if err = parseRequest(r, &req); err != nil {
@ -302,7 +201,7 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
node, err := session.Node(req.ID) node, err := RetrieveNode(sess.UserID, req.ID)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -313,15 +212,9 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
"Node": node, "Node": node,
}) })
} // }}} } // }}}
func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/create") logger.Info("webserver", "request", "/node/create")
var err error var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
Name string Name string
@ -332,7 +225,7 @@ func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
node, err := session.CreateNode(req.ParentID, req.Name) node, err := CreateNode(sess.UserID, req.ParentID, req.Name)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -343,15 +236,9 @@ func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
"Node": node, "Node": node,
}) })
} // }}} } // }}}
func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/update") logger.Info("webserver", "request", "/node/update")
var err error var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
NodeID int NodeID int
@ -363,7 +250,7 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
err = session.UpdateNode(req.NodeID, req.Content, req.CryptoKeyID) err = UpdateNode(sess.UserID, req.NodeID, req.Content, req.CryptoKeyID)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -373,16 +260,9 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
"OK": true, "OK": true,
}) })
} // }}} } // }}}
func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/rename")
var err error var err error
var session Session logger.Info("webserver", "request", "/node/rename")
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
NodeID int NodeID int
@ -393,7 +273,7 @@ func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
err = session.RenameNode(req.NodeID, req.Name) err = RenameNode(sess.UserID, req.NodeID, req.Name)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -403,16 +283,9 @@ func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
"OK": true, "OK": true,
}) })
} // }}} } // }}}
func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/delete")
var err error var err error
var session Session logger.Info("webserver", "request", "/node/delete")
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
NodeID int NodeID int
@ -422,7 +295,7 @@ func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
err = session.DeleteNode(req.NodeID) err = DeleteNode(sess.UserID, req.NodeID)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -432,15 +305,9 @@ func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
"OK": true, "OK": true,
}) })
} // }}} } // }}}
func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/upload") logger.Info("webserver", "request", "/node/upload")
var err error var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
// Parse our multipart form, 10 << 20 specifies a maximum // Parse our multipart form, 10 << 20 specifies a maximum
// upload of 10 MB files. // upload of 10 MB files.
@ -479,7 +346,7 @@ func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
MIME: handler.Header.Get("Content-Type"), MIME: handler.Header.Get("Content-Type"),
MD5: md5sum, MD5: md5sum,
} }
if err = session.AddFile(&nodeFile); err != nil { if err = AddFile(sess.UserID, &nodeFile); err != nil {
responseError(w, err) responseError(w, err)
return return
} }
@ -516,17 +383,11 @@ func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
"File": nodeFile, "File": nodeFile,
}) })
} // }}} } // }}}
func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/download") logger.Info("webserver", "request", "/node/download")
var err error var err error
var session Session
var files []File var files []File
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
NodeID int NodeID int
FileID int FileID int
@ -536,7 +397,7 @@ func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
files, err = session.Files(req.NodeID, req.FileID) files, err = Files(sess.UserID, req.NodeID, req.FileID)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -583,17 +444,11 @@ func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
} }
} // }}} } // }}}
func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/files") logger.Info("webserver", "request", "/node/files")
var err error var err error
var session Session
var files []File var files []File
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
NodeID int NodeID int
}{} }{}
@ -602,7 +457,7 @@ func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
files, err = session.Files(req.NodeID, 0) files, err = Files(sess.UserID, req.NodeID, 0)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -613,17 +468,11 @@ func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
"Files": files, "Files": files,
}) })
} // }}} } // }}}
func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/search") logger.Info("webserver", "request", "/node/search")
var err error var err error
var session Session
var nodes []Node var nodes []Node
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
Search string Search string
}{} }{}
@ -632,7 +481,7 @@ func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
nodes, err = session.SearchNodes(req.Search) nodes, err = SearchNodes(sess.UserID, req.Search)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -644,17 +493,11 @@ func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
}) })
} // }}} } // }}}
func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{ func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/key/retrieve") logger.Info("webserver", "request", "/key/retrieve")
var err error var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil { keys, err := Keys(sess.UserID)
responseError(w, err)
return
}
keys, err := session.Keys()
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -665,15 +508,9 @@ func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
"Keys": keys, "Keys": keys,
}) })
} // }}} } // }}}
func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/key/create") logger.Info("webserver", "request", "/key/create")
var err error var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct { req := struct {
Description string Description string
@ -684,7 +521,7 @@ func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
return return
} }
key, err := session.KeyCreate(req.Description, req.Key) key, err := KeyCreate(sess.UserID, req.Description, req.Key)
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -695,17 +532,11 @@ func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
"Key": key, "Key": key,
}) })
} // }}} } // }}}
func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{ func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/key/counter") logger.Info("webserver", "request", "/key/counter")
var err error var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil { counter, err := KeyCounter()
responseError(w, err)
return
}
counter, err := session.KeyCounter()
if err != nil { if err != nil {
responseError(w, err) responseError(w, err)
return return
@ -718,20 +549,4 @@ func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{
}) })
} // }}} } // }}}
func newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{
// Append index.html if needed for further reading of the file
p := requestPath
if p[len(p)-1] == '/' {
p += "index.html"
}
p = config.Application.Directories.Static + p
base := path.Base(p)
if tmpl, err = template.New(base).ParseFiles(p); err != nil {
return
}
return
} // }}}
// vim: foldmethod=marker // vim: foldmethod=marker

71
node.go
View File

@ -25,9 +25,9 @@ type Node struct {
ContentEncrypted string `db:"content_encrypted" json:"-"` ContentEncrypted string `db:"content_encrypted" json:"-"`
} }
func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {{{ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(` rows, err = service.Db.Conn.Queryx(`
WITH RECURSIVE nodetree AS ( WITH RECURSIVE nodetree AS (
SELECT SELECT
*, *,
@ -62,7 +62,7 @@ func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {
ORDER BY ORDER BY
path ASC path ASC
`, `,
session.UserID, userID,
startNodeID, startNodeID,
) )
if err != nil { if err != nil {
@ -87,9 +87,9 @@ func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {
return return
}// }}} }// }}}
func (session Session) RootNode() (node Node, err error) {// {{{ func RootNode(userID int) (node Node, err error) {// {{{
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(` rows, err = service.Db.Conn.Queryx(`
SELECT SELECT
id, id,
user_id, user_id,
@ -100,7 +100,7 @@ func (session Session) RootNode() (node Node, err error) {// {{{
user_id = $1 AND user_id = $1 AND
parent_id IS NULL parent_id IS NULL
`, `,
session.UserID, userID,
) )
if err != nil { if err != nil {
return return
@ -108,7 +108,7 @@ func (session Session) RootNode() (node Node, err error) {// {{{
defer rows.Close() defer rows.Close()
node.Name = "Start" node.Name = "Start"
node.UserID = session.UserID node.UserID = userID
node.Complete = true node.Complete = true
node.Children = []Node{} node.Children = []Node{}
node.Crumbs = []Node{} node.Crumbs = []Node{}
@ -129,13 +129,13 @@ func (session Session) RootNode() (node Node, err error) {// {{{
return return
}// }}} }// }}}
func (session Session) Node(nodeID int) (node Node, err error) {// {{{ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
if nodeID == 0 { if nodeID == 0 {
return session.RootNode() return RootNode(userID)
} }
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(` rows, err = service.Db.Conn.Queryx(`
WITH RECURSIVE recurse AS ( WITH RECURSIVE recurse AS (
SELECT SELECT
id, id,
@ -170,7 +170,7 @@ func (session Session) Node(nodeID int) (node Node, err error) {// {{{
SELECT * FROM recurse ORDER BY level ASC SELECT * FROM recurse ORDER BY level ASC
`, `,
session.UserID, userID,
nodeID, nodeID,
) )
if err != nil { if err != nil {
@ -217,14 +217,14 @@ func (session Session) Node(nodeID int) (node Node, err error) {// {{{
} }
} }
node.Crumbs, err = session.NodeCrumbs(node.ID) node.Crumbs, err = NodeCrumbs(node.ID)
node.Files, err = session.Files(node.ID, 0) node.Files, err = Files(userID, node.ID, 0)
return return
}// }}} }// }}}
func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{ func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(` rows, err = service.Db.Conn.Queryx(`
WITH RECURSIVE nodes AS ( WITH RECURSIVE nodes AS (
SELECT SELECT
id, id,
@ -260,10 +260,10 @@ func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
} }
return return
}// }}} }// }}}
func (session Session) CreateNode(parentID int, name string) (node Node, err error) {// {{{ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{{
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(` rows, err = service.Db.Conn.Queryx(`
INSERT INTO node(user_id, parent_id, name) INSERT INTO node(user_id, parent_id, name)
VALUES($1, NULLIF($2, 0)::integer, $3) VALUES($1, NULLIF($2, 0)::integer, $3)
RETURNING RETURNING
@ -273,7 +273,7 @@ func (session Session) CreateNode(parentID int, name string) (node Node, err err
name, name,
content content
`, `,
session.UserID, userID,
parentID, parentID,
name, name,
) )
@ -292,12 +292,12 @@ func (session Session) CreateNode(parentID int, name string) (node Node, err err
node.Complete = true node.Complete = true
} }
node.Crumbs, err = session.NodeCrumbs(node.ID) node.Crumbs, err = NodeCrumbs(node.ID)
return return
}// }}} }// }}}
func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (err error) {// {{{ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) {// {{{
if cryptoKeyID > 0 { if cryptoKeyID > 0 {
_, err = db.Exec(` _, err = service.Db.Conn.Exec(`
UPDATE node UPDATE node
SET SET
content = '', content = '',
@ -313,10 +313,10 @@ func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (
content, content,
cryptoKeyID, cryptoKeyID,
nodeID, nodeID,
session.UserID, userID,
) )
} else { } else {
_, err = db.Exec(` _, err = service.Db.Conn.Exec(`
UPDATE node UPDATE node
SET SET
content = $1, content = $1,
@ -332,24 +332,24 @@ func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (
content, content,
cryptoKeyID, cryptoKeyID,
nodeID, nodeID,
session.UserID, userID,
) )
} }
return return
}// }}} }// }}}
func (session Session) RenameNode(nodeID int, name string) (err error) {// {{{ func RenameNode(userID, nodeID int, name string) (err error) {// {{{
_, err = db.Exec(` _, err = service.Db.Conn.Exec(`
UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3 UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3
`, `,
name, name,
session.UserID, userID,
nodeID, nodeID,
) )
return return
}// }}} }// }}}
func (session Session) DeleteNode(nodeID int) (err error) {// {{{ func DeleteNode(userID, nodeID int) (err error) {// {{{
_, err = db.Exec(` _, err = service.Db.Conn.Exec(`
WITH RECURSIVE nodetree AS ( WITH RECURSIVE nodetree AS (
SELECT SELECT
id, parent_id id, parent_id
@ -368,15 +368,15 @@ func (session Session) DeleteNode(nodeID int) (err error) {// {{{
DELETE FROM node WHERE id IN ( DELETE FROM node WHERE id IN (
SELECT id FROM nodetree SELECT id FROM nodetree
)`, )`,
session.UserID, userID,
nodeID, nodeID,
) )
return return
}// }}} }// }}}
func (session Session) SearchNodes(search string) (nodes []Node, err error) {// {{{ func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
nodes = []Node{} nodes = []Node{}
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Queryx(` rows, err = service.Db.Conn.Queryx(`
SELECT SELECT
id, id,
user_id, user_id,
@ -385,14 +385,15 @@ func (session Session) SearchNodes(search string) (nodes []Node, err error) {//
updated updated
FROM node FROM node
WHERE WHERE
user_id = $1
crypto_key_id IS NULL AND crypto_key_id IS NULL AND
( (
content ~* $1 OR content ~* $2 OR
name ~* $1 name ~* $2
) )
ORDER BY ORDER BY
updated DESC updated DESC
`, search) `, userID, search)
if err != nil { if err != nil {
return return
} }

View File

@ -1,119 +0,0 @@
package main
import (
// Standard
"database/sql"
"errors"
"fmt"
"net/http"
"time"
)
type Session struct {
UUID string
UserID int
Created time.Time
}
func CreateSession() (session Session, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
INSERT INTO public.session(uuid)
VALUES(gen_random_uuid())
RETURNING uuid, created`,
); err != nil {
return
}
defer rows.Close()
if rows.Next() {
rows.Scan(&session.UUID, &session.Created)
}
return
}// }}}
func sessionUUID(r *http.Request) (string, error) {// {{{
headers := r.Header["X-Session-Id"]
if len(headers) > 0 {
return headers[0], nil
}
return "", errors.New("Invalid session")
}// }}}
func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{
var uuid string
if uuid, err = sessionUUID(r); err != nil {
return
}
session.UUID = uuid
if found, err = session.Retrieve(); err != nil {
return
}
if notFoundIsError && !found {
err = errors.New("Invalid session")
return
}
return
}// }}}
func (session *Session) Retrieve() (found bool, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
SELECT
uuid, user_id, created
FROM public.session
WHERE
uuid = $1 AND
created + $2::interval >= NOW()
`,
session.UUID,
fmt.Sprintf("%d days", config.Session.DaysValid),
); err != nil {
return
}
defer rows.Close()
found = false
if rows.Next() {
found = true
rows.Scan(&session.UUID, &session.UserID, &session.Created)
}
return
}// }}}
func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
SELECT id
FROM public.user
WHERE
username=$1 AND
password=password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
`,
username,
password,
); err != nil {
return
}
defer rows.Close()
if rows.Next() {
rows.Scan(&session.UserID)
authenticated = session.UserID > 0
}
if authenticated {
_, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID)
if err != nil {
return
}
}
return
}// }}}
// vim: foldmethod=marker

View File

@ -1,34 +0,0 @@
/* Required for the gen_random_bytes function */
CREATE EXTENSION pgcrypto;
CREATE FUNCTION password_hash(salt_hex char(32), pass bytea)
RETURNS char(96)
LANGUAGE plpgsql
AS
$$
BEGIN
RETURN (
SELECT
salt_hex ||
encode(
sha256(
decode(salt_hex, 'hex') || /* salt in binary */
pass /* password */
),
'hex'
)
);
END;
$$;
/* Password has to be able to accommodate 96 characters instead of previous 64.
* It can't be char(96), because then the password would be padded to 96 characters. */
ALTER TABLE public."user" ALTER COLUMN "password" TYPE varchar(96) USING "password"::varchar;
/* Update all users with salted and hashed passwords */
UPDATE public.user
SET password = password_hash( encode(gen_random_bytes(16),'hex'), password::bytea);
/* After the password hashing, all passwords are now hex encoded 32 characters salt and 64 characters hash,
* and the varchar type is not longer necessary. */
ALTER TABLE public."user" ALTER COLUMN "password" TYPE char(96) USING "password"::varchar;

View File

@ -13,8 +13,8 @@ class App extends Component {
this.websocket = null this.websocket = null
this.websocket_int_ping = null this.websocket_int_ping = null
this.websocket_int_reconnect = null this.websocket_int_reconnect = null
this.wsConnect() //this.wsConnect() // XXX
this.wsLoop() //this.wsLoop() // XXX
this.session = new Session(this) this.session = new Session(this)
this.session.initialize() this.session.initialize()
@ -25,7 +25,7 @@ class App extends Component {
this.startNode = null this.startNode = null
this.setStartNode() //this.setStartNode()
}//}}} }//}}}
render() {//{{{ render() {//{{{
let app_el = document.getElementById('app') let app_el = document.getElementById('app')
@ -114,7 +114,7 @@ class App extends Component {
} }
if(this.session.UUID !== '') if(this.session.UUID !== '')
headers['X-Session-Id'] = this.session.UUID headers['X-Session-ID'] = this.session.UUID
fetch(url, { fetch(url, {
method: 'POST', method: 'POST',
@ -201,9 +201,24 @@ class Tree extends Component {
this.selectedTreeNode = null this.selectedTreeNode = null
this.props.app.tree = this this.props.app.tree = this
this.retrieve()
}//}}}
render({ app }) {//{{{
let renderedTreeTrunk = this.treeTrunk.map(node=>{
this.treeNodeComponents[node.ID] = createRef()
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
})
return html`<div id="tree">${renderedTreeTrunk}</div>`
}//}}}
retrieve(callback = null) {//{{{
this.props.app.request('/node/tree', { StartNodeID: 0 }) this.props.app.request('/node/tree', { StartNodeID: 0 })
.then(res=>{ .then(res=>{
this.treeNodes = {}
this.treeNodeComponents = {}
this.treeTrunk = []
this.selectedTreeNode = null
// A tree of nodes is built. This requires the list of nodes // A tree of nodes is built. This requires the list of nodes
// returned from the server to be sorted in such a way that // returned from the server to be sorted in such a way that
// a parent node always appears before a child node. // a parent node always appears before a child node.
@ -235,17 +250,12 @@ class Tree extends Component {
this.crumbsUpdateNodes() this.crumbsUpdateNodes()
this.forceUpdate() this.forceUpdate()
if(callback)
callback()
}) })
.catch(this.responseError) .catch(this.responseError)
}//}}} }//}}}
render({ app }) {//{{{
let renderedTreeTrunk = this.treeTrunk.map(node=>{
this.treeNodeComponents[node.ID] = createRef()
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
})
return html`<div id="tree">${renderedTreeTrunk}</div>`
}//}}}
setSelected(node) {//{{{ setSelected(node) {//{{{
if(this.selectedTreeNode) if(this.selectedTreeNode)
this.selectedTreeNode.selected.value = false this.selectedTreeNode.selected.value = false

View File

@ -215,7 +215,14 @@ export class NodeUI extends Component {
let name = prompt("Name") let name = prompt("Name")
if(!name) if(!name)
return return
this.node.value.create(name, nodeID=>this.goToNode(nodeID)) this.node.value.create(name, nodeID=>{
console.log('before', this.props.app.startNode)
this.props.app.startNode = new Node(this.props.app, nodeID)
console.log('after', this.props.app.startNode)
this.props.app.tree.retrieve(()=>{
this.goToNode(nodeID)
})
})
}//}}} }//}}}
saveNode() {//{{{ saveNode() {//{{{
let nodeContent = this.nodeContent.current let nodeContent = this.nodeContent.current

View File

@ -10,11 +10,10 @@ export class Session {
// Retrieving the stored session UUID, if any. // Retrieving the stored session UUID, if any.
// If one found, validate with server. // If one found, validate with server.
// If the browser doesn't know anything about a session, // If the browser doesn't know anything about a session,
// a call to /session/create is necessary to retrieve a session UUID. // a call to /session/create is necessary to retrieve a session UUID.
let uuid= window.localStorage.getItem("session.UUID") let uuid = window.localStorage.getItem("session.UUID")
if(uuid === null) { if (uuid === null) {
this.create() this.create()
return return
} }
@ -25,12 +24,12 @@ export class Session {
// A call to /session/retrieve with a session UUID validates that the // A call to /session/retrieve with a session UUID validates that the
// session is still valid and returns all session information. // session is still valid and returns all session information.
this.UUID = uuid this.UUID = uuid
this.app.request('/session/retrieve', {}) this.app.request('/_session/retrieve', {})
.then(res=>{ .then(res => {
if(res.Valid) { if (res.Error === undefined) {
// Session exists on server. // Session exists on server.
// Not necessarily authenticated. // Not necessarily authenticated.
this.UserID = res.Session.UserID // could be 0 this.UserID = res.UserID // could be 0
this.initialized = true this.initialized = true
this.app.forceUpdate() this.app.forceUpdate()
} else { } else {
@ -41,10 +40,10 @@ export class Session {
.catch(this.app.responseError) .catch(this.app.responseError)
}//}}} }//}}}
create() {//{{{ create() {//{{{
this.app.request('/session/create', {}) this.app.request('/_session/new', {})
.then(res=>{ .then(res => {
this.UUID = res.Session.UUID this.UUID = res.Session.UUID
window.localStorage.setItem('session.UUID', this.UUID) window.localStorage.setItem('session.UUID', this.Session.UUID)
this.initialized = true this.initialized = true
this.app.forceUpdate() this.app.forceUpdate()
}) })
@ -53,12 +52,12 @@ export class Session {
authenticate(username, password) {//{{{ authenticate(username, password) {//{{{
this.app.login.current.authentication_failed.value = false this.app.login.current.authentication_failed.value = false
this.app.request('/session/authenticate', { this.app.request('/_session/authenticate', {
username, username,
password, password,
}) })
.then(res=>{ .then(res => {
if(res.Authenticated) { if (res.Authenticated) {
this.UserID = res.Session.UserID this.UserID = res.Session.UserID
this.app.forceUpdate() this.app.forceUpdate()
} else { } else {

47
user.go
View File

@ -1,11 +1,6 @@
package main package main
import ( /*
// Standard
"database/sql"
"fmt"
)
func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err error) { func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err error) {
var result sql.Result var result sql.Result
var rowsAffected int64 var rowsAffected int64
@ -14,10 +9,10 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
UPDATE public.user UPDATE public.user
SET SET
password = password_hash( password = password_hash(
/* salt in hex */ / salt in hex /
ENCODE(gen_random_bytes(16), 'hex'), ENCODE(gen_random_bytes(16), 'hex'),
/* password */ / password /
$1::bytea $1::bytea
) )
WHERE WHERE
@ -36,38 +31,4 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
return rowsAffected > 0, nil return rowsAffected > 0, nil
} }
*/
func createUser() (err error) {
var username, password string
fmt.Printf("Username: ")
fmt.Scanln(&username)
fmt.Printf("Password: ")
fmt.Scanln(&password)
err = CreateDbUser(username, password)
return
}
func CreateDbUser(username string, password string) (err error) {
var result sql.Result
result, err = db.Exec(`
INSERT INTO public.user(username, password)
VALUES(
$1,
password_hash(
/* salt in hex */
ENCODE(gen_random_bytes(16), 'hex'),
/* password */
$2::bytea
)
)
`,
username,
password,
)
_, err = result.RowsAffected()
return
}