Compare commits

..

2 Commits

Author SHA1 Message Date
Magnus Åhall
afa113b8b7 User session fixes 2024-01-05 21:14:55 +01:00
Magnus Åhall
52fba2289e wip: rewrite to webservice library 2024-01-05 20:00:02 +01:00
23 changed files with 205 additions and 715 deletions

150
db.go
View File

@ -1,150 +0,0 @@
package main
import (
// External
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
// Standard
"errors"
"fmt"
"io/fs"
"regexp"
"strconv"
)
var (
dbConn string
db *sqlx.DB
)
func dbInit() (err error) { // {{{
dbConn = fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
config.Database.Host,
config.Database.Port,
config.Database.Username,
config.Database.Password,
config.Database.Name,
)
logger.Info("db", "op", "connect", "host", config.Database.Host, "port", config.Database.Port)
if db, err = sqlx.Connect("postgres", dbConn); err != nil {
return
}
if err = dbVerifyInternals(); err != nil {
return
}
err = dbUpdate()
return
} // }}}
func dbVerifyInternals() (err error) { // {{{
var rows *sqlx.Rows
if rows, err = db.Queryx(
`SELECT EXISTS (
SELECT FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = '_internal'
AND c.relname = 'db'
)`,
); err != nil {
return
}
defer rows.Close()
var exists bool
rows.Next()
if err = rows.Scan(&exists); err != nil {
return
}
if !exists {
logger.Info("db", "op", "create_db", "db", "_internal.db")
if _, err = db.Exec(`
CREATE SCHEMA "_internal";
CREATE TABLE "_internal".db (
"key" varchar NOT NULL,
value varchar NULL,
CONSTRAINT db_pk PRIMARY KEY (key)
);
INSERT INTO _internal.db("key", "value")
VALUES('schema', '0');
`,
); err != nil {
return
}
}
return
} // }}}
func dbUpdate() (err error) { // {{{
/* Current schema revision is read from database.
* Used to iterate through the embedded SQL updates
* up to the db schema version currently compiled
* program is made for. */
var rows *sqlx.Rows
var schemaStr string
var schema int
rows, err = db.Queryx(`SELECT value FROM _internal.db WHERE "key"='schema'`)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return errors.New("Table _interval.db missing schema row")
}
if err = rows.Scan(&schemaStr); err != nil {
return
}
// Run updates
schema, err = strconv.Atoi(schemaStr)
if err != nil {
return err
}
sqlSchemaVersion := sqlSchema()
for i := (schema + 1); i <= sqlSchemaVersion; i++ {
logger.Info("db", "op", "upgrade_schema", "schema", i)
sql, _ := embedded.ReadFile(
fmt.Sprintf("sql/%04d.sql", i),
)
_, err = db.Exec(string(sql))
if err != nil {
return
}
_, err = db.Exec(`UPDATE _internal.db SET "value"=$1 WHERE "key"='schema'`, i)
if err != nil {
return
}
logger.Info("db", "op", "upgrade_schema", "schema", i, "result", "ok")
}
return
} // }}}
func sqlSchema() (max int) { // {{{
var num int
files, _ := fs.ReadDir(embedded, "sql")
sqlFilename := regexp.MustCompile(`^([0-9]+)\.sql$`)
for _, file := range files {
fname := sqlFilename.FindStringSubmatch(file.Name())
if len(fname) != 2 {
continue
}
num, _ = strconv.Atoi(fname[1])
}
if num > max {
max = num
}
return
} // }}}
// vim: foldmethod=marker

14
file.go
View File

@ -5,7 +5,6 @@ import (
"github.com/jmoiron/sqlx"
// Standard
"fmt"
"time"
)
@ -20,11 +19,11 @@ type File struct {
Uploaded time.Time
}
func (session Session) AddFile(file *File) (err error) { // {{{
file.UserID = session.UserID
func AddFile(userID int, file *File) (err error) { // {{{
file.UserID = userID
var rows *sqlx.Rows
rows, err = db.Queryx(`
rows, err = service.Db.Conn.Queryx(`
INSERT INTO file(user_id, node_id, filename, size, mime, md5)
VALUES($1, $2, $3, $4, $5, $6)
RETURNING id
@ -43,12 +42,11 @@ func (session Session) AddFile(file *File) (err error) { // {{{
rows.Next()
err = rows.Scan(&file.ID)
fmt.Printf("%#v\n", file)
return
} // }}}
func (session Session) Files(nodeID, fileID int) (files []File, err error) { // {{{
func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
var rows *sqlx.Rows
rows, err = db.Queryx(
rows, err = service.Db.Conn.Queryx(
`SELECT *
FROM file
WHERE
@ -58,7 +56,7 @@ func (session Session) Files(nodeID, fileID int) (files []File, err error) { //
WHEN 0 THEN true
ELSE id = $3
END`,
session.UserID,
userID,
nodeID,
fileID,
)

16
key.go
View File

@ -3,8 +3,6 @@ package main
import (
// External
"github.com/jmoiron/sqlx"
// Standard
)
type Key struct {
@ -14,9 +12,9 @@ type Key struct {
Key string
}
func (session Session) Keys() (keys []Key, err error) {// {{{
func Keys(userID int) (keys []Key, err error) { // {{{
var rows *sqlx.Rows
if rows, err = db.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, session.UserID); err != nil {
if rows, err = service.Db.Conn.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, userID); err != nil {
return
}
defer rows.Close()
@ -32,13 +30,13 @@ func (session Session) Keys() (keys []Key, err error) {// {{{
return
} // }}}
func (session Session) KeyCreate(description, keyEncoded string) (key Key, err error) {// {{{
func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) { // {{{
var row *sqlx.Rows
if row, err = db.Queryx(
if row, err = service.Db.Conn.Queryx(
`INSERT INTO crypto_key(user_id, description, key)
VALUES($1, $2, $3)
RETURNING *`,
session.UserID,
userID,
description,
keyEncoded,
); err != nil {
@ -55,9 +53,9 @@ func (session Session) KeyCreate(description, keyEncoded string) (key Key, err e
return
} // }}}
func (session Session) KeyCounter() (counter int64, err error) {// {{{
func KeyCounter() (counter int64, err error) { // {{{
var rows *sqlx.Rows
rows, err = db.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
rows, err = service.Db.Conn.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
if err != nil {
return
}

346
main.go
View File

@ -1,20 +1,23 @@
package main
import (
// External
"git.gibonuddevalla.se/go/webservice"
// Internal
"git.gibonuddevalla.se/go/webservice/session"
// Standard
"crypto/md5"
"embed"
"encoding/hex"
"flag"
"fmt"
"html/template"
"io"
"log/slog"
"net/http"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
)
@ -25,8 +28,10 @@ var (
flagPort int
flagVersion bool
flagCreateUser bool
flagCheckLocal bool
flagConfig string
service *webservice.Service
connectionManager ConnectionManager
static http.Handler
config Config
@ -34,11 +39,24 @@ var (
VERSION string
//go:embed version sql/*
embedded embed.FS
embeddedSQL embed.FS
//go:embed static
staticFS embed.FS
)
func sqlProvider(dbname string, version int) (sql []byte, found bool) {
var err error
sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version))
if err != nil {
return
}
found = true
return
}
func init() { // {{{
version, _ := embedded.ReadFile("version")
version, _ := embeddedSQL.ReadFile("version")
VERSION = strings.TrimSpace(string(version))
opt := slog.HandlerOptions{}
@ -48,6 +66,7 @@ func init() { // {{{
flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on")
flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists")
flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit")
flag.BoolVar(&flagCheckLocal, "checklocal", false, "Check for local static file before embedded")
flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file")
flag.Parse()
} // }}}
@ -58,56 +77,47 @@ func main() { // {{{
fmt.Printf("%s\n", VERSION)
os.Exit(0)
}
logger.Info("application", "version", VERSION)
config, err = ConfigRead(flagConfig)
service, err = webservice.New(flagConfig, VERSION)
if err != nil {
logger.Error("config", "error", err)
logger.Error("application", "error", err)
os.Exit(1)
}
service.SetDatabase(sqlProvider)
service.SetStaticDirectory("static", true)
service.SetStaticFS(staticFS, "static")
if err = dbInit(); err != nil {
logger.Error("db", "error", err)
os.Exit(1)
}
service.Register("/node/upload", true, true, nodeUpload)
service.Register("/node/tree", true, true, nodeTree)
service.Register("/node/retrieve", true, true, nodeRetrieve)
service.Register("/node/create", true, true, nodeCreate)
service.Register("/node/update", true, true, nodeUpdate)
service.Register("/node/rename", true, true, nodeRename)
service.Register("/node/delete", true, true, nodeDelete)
service.Register("/node/download", true, true, nodeDownload)
service.Register("/node/search", true, true, nodeSearch)
service.Register("/key/retrieve", true, true, keyRetrieve)
service.Register("/key/create", true, true, keyCreate)
service.Register("/key/counter", true, true, keyCounter)
service.Register("/ws", false, false, service.WebsocketHandler)
service.Register("/", false, false, service.StaticHandler)
if flagCreateUser {
err = createUser()
if err != nil {
logger.Error("db", "error", err)
os.Exit(1)
}
service.CreateUserPrompt()
os.Exit(0)
}
err = service.Start()
if err != nil {
logger.Error("webserver", "error", err)
os.Exit(1)
}
connectionManager = NewConnectionManager()
go connectionManager.BroadcastLoop()
static = http.FileServer(http.Dir(config.Application.Directories.Static))
http.HandleFunc("/css_updated", cssUpdateHandler)
http.HandleFunc("/session/create", sessionCreate)
http.HandleFunc("/session/retrieve", sessionRetrieve)
http.HandleFunc("/session/authenticate", sessionAuthenticate)
http.HandleFunc("/user/password", userPassword)
http.HandleFunc("/node/tree", nodeTree)
http.HandleFunc("/node/retrieve", nodeRetrieve)
http.HandleFunc("/node/create", nodeCreate)
http.HandleFunc("/node/update", nodeUpdate)
http.HandleFunc("/node/rename", nodeRename)
http.HandleFunc("/node/delete", nodeDelete)
http.HandleFunc("/node/upload", nodeUpload)
http.HandleFunc("/node/download", nodeDownload)
http.HandleFunc("/node/search", nodeSearch)
http.HandleFunc("/key/retrieve", keyRetrieve)
http.HandleFunc("/key/create", keyCreate)
http.HandleFunc("/key/counter", keyCounter)
http.HandleFunc("/ws", websocketHandler)
http.HandleFunc("/", staticHandler)
listen := fmt.Sprintf("%s:%d", LISTEN_HOST, flagPort)
logger.Info("webserver", "listen", listen, "domains", config.Websocket.Domains)
http.ListenAndServe(listen, nil)
} // }}}
func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{
@ -127,108 +137,8 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{
return
}
} // }}}
func staticHandler(w http.ResponseWriter, r *http.Request) { // {{{
data := struct {
VERSION string
}{
VERSION: VERSION,
}
// URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version.
// To get rid of problems with cached content in browser on a new version release,
// while also not disabling cache altogether.
logger.Debug("webserver", "request", r.URL.Path)
if r.URL.Path == "/favicon.ico" {
static.ServeHTTP(w, r)
return
}
rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
static.ServeHTTP(w, r)
return
}
// Everything else is run through the template system.
// For now to get VERSION into files to fix caching.
logger.Debug("webserver", "template", r.URL.Path)
tmpl, err := newTemplate(r.URL.Path)
if err != nil {
if os.IsNotExist(err) {
w.WriteHeader(404)
}
w.Write([]byte(err.Error()))
return
}
if err = tmpl.Execute(w, data); err != nil {
w.Write([]byte(err.Error()))
}
} // }}}
func sessionCreate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/create")
session, err := CreateSession()
if err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Session": session,
})
} // }}}
func sessionRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/retrieve")
var err error
var found bool
var session Session
if session, found, err = ValidateSession(r, false); err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Valid": found,
"Session": session,
})
} // }}}
func sessionAuthenticate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/authenticate")
var err error
var session Session
var authenticated bool
// Validate session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Username string
Password string
}{}
if err = parseRequest(r, &req); err != nil {
responseError(w, err)
return
}
if authenticated, err = session.Authenticate(req.Username, req.Password); err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Authenticated": authenticated,
"Session": session,
})
} // }}}
/*
func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
var err error
var ok bool
@ -259,15 +169,11 @@ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
"CurrentPasswordOK": ok,
})
} // }}}
*/
func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/tree")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct{ StartNodeID int }{}
if err = parseRequest(r, &req); err != nil {
@ -275,7 +181,7 @@ func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
return
}
nodes, err := session.NodeTree(req.StartNodeID)
nodes, err := NodeTree(sess.UserID, req.StartNodeID)
if err != nil {
responseError(w, err)
return
@ -286,15 +192,9 @@ func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
"Nodes": nodes,
})
} // }}}
func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/retrieve")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct{ ID int }{}
if err = parseRequest(r, &req); err != nil {
@ -302,7 +202,7 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
return
}
node, err := session.Node(req.ID)
node, err := RetrieveNode(sess.UserID, req.ID)
if err != nil {
responseError(w, err)
return
@ -313,15 +213,9 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
"Node": node,
})
} // }}}
func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/create")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Name string
@ -332,7 +226,7 @@ func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
return
}
node, err := session.CreateNode(req.ParentID, req.Name)
node, err := CreateNode(sess.UserID, req.ParentID, req.Name)
if err != nil {
responseError(w, err)
return
@ -343,15 +237,9 @@ func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
"Node": node,
})
} // }}}
func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/update")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
@ -363,7 +251,7 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
return
}
err = session.UpdateNode(req.NodeID, req.Content, req.CryptoKeyID)
err = UpdateNode(sess.UserID, req.NodeID, req.Content, req.CryptoKeyID)
if err != nil {
responseError(w, err)
return
@ -373,16 +261,9 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
"OK": true,
})
} // }}}
func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/rename")
func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
logger.Info("webserver", "request", "/node/rename")
req := struct {
NodeID int
@ -393,7 +274,7 @@ func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
return
}
err = session.RenameNode(req.NodeID, req.Name)
err = RenameNode(sess.UserID, req.NodeID, req.Name)
if err != nil {
responseError(w, err)
return
@ -403,16 +284,9 @@ func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
"OK": true,
})
} // }}}
func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/delete")
func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
logger.Info("webserver", "request", "/node/delete")
req := struct {
NodeID int
@ -422,7 +296,7 @@ func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
return
}
err = session.DeleteNode(req.NodeID)
err = DeleteNode(sess.UserID, req.NodeID)
if err != nil {
responseError(w, err)
return
@ -432,15 +306,9 @@ func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
"OK": true,
})
} // }}}
func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/upload")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
// Parse our multipart form, 10 << 20 specifies a maximum
// upload of 10 MB files.
@ -479,7 +347,7 @@ func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
MIME: handler.Header.Get("Content-Type"),
MD5: md5sum,
}
if err = session.AddFile(&nodeFile); err != nil {
if err = AddFile(sess.UserID, &nodeFile); err != nil {
responseError(w, err)
return
}
@ -516,17 +384,11 @@ func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
"File": nodeFile,
})
} // }}}
func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/download")
var err error
var session Session
var files []File
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
FileID int
@ -536,7 +398,7 @@ func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
return
}
files, err = session.Files(req.NodeID, req.FileID)
files, err = Files(sess.UserID, req.NodeID, req.FileID)
if err != nil {
responseError(w, err)
return
@ -583,17 +445,11 @@ func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
}
} // }}}
func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/files")
var err error
var session Session
var files []File
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
}{}
@ -602,7 +458,7 @@ func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
return
}
files, err = session.Files(req.NodeID, 0)
files, err = Files(sess.UserID, req.NodeID, 0)
if err != nil {
responseError(w, err)
return
@ -613,17 +469,11 @@ func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
"Files": files,
})
} // }}}
func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/search")
var err error
var session Session
var nodes []Node
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Search string
}{}
@ -632,7 +482,7 @@ func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
return
}
nodes, err = session.SearchNodes(req.Search)
nodes, err = SearchNodes(sess.UserID, req.Search)
if err != nil {
responseError(w, err)
return
@ -644,17 +494,11 @@ func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
})
} // }}}
func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/key/retrieve")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
keys, err := session.Keys()
keys, err := Keys(sess.UserID)
if err != nil {
responseError(w, err)
return
@ -665,15 +509,9 @@ func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
"Keys": keys,
})
} // }}}
func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/key/create")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Description string
@ -684,7 +522,7 @@ func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
return
}
key, err := session.KeyCreate(req.Description, req.Key)
key, err := KeyCreate(sess.UserID, req.Description, req.Key)
if err != nil {
responseError(w, err)
return
@ -695,17 +533,11 @@ func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
"Key": key,
})
} // }}}
func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{
func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/key/counter")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
counter, err := session.KeyCounter()
counter, err := KeyCounter()
if err != nil {
responseError(w, err)
return
@ -718,20 +550,4 @@ func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{
})
} // }}}
func newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{
// Append index.html if needed for further reading of the file
p := requestPath
if p[len(p)-1] == '/' {
p += "index.html"
}
p = config.Application.Directories.Static + p
base := path.Base(p)
if tmpl, err = template.New(base).ParseFiles(p); err != nil {
return
}
return
} // }}}
// vim: foldmethod=marker

74
node.go
View File

@ -25,9 +25,9 @@ type Node struct {
ContentEncrypted string `db:"content_encrypted" json:"-"`
}
func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {{{
func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = db.Queryx(`
rows, err = service.Db.Conn.Queryx(`
WITH RECURSIVE nodetree AS (
SELECT
*,
@ -62,7 +62,7 @@ func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {
ORDER BY
path ASC
`,
session.UserID,
userID,
startNodeID,
)
if err != nil {
@ -79,6 +79,9 @@ func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {
for rows.Next() {
node := Node{}
node.Complete = false
node.Crumbs = []Node{}
node.Children = []Node{}
node.Files = []File{}
if err = rows.StructScan(&node); err != nil {
return
}
@ -87,9 +90,9 @@ func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {
return
}// }}}
func (session Session) RootNode() (node Node, err error) {// {{{
func RootNode(userID int) (node Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = db.Queryx(`
rows, err = service.Db.Conn.Queryx(`
SELECT
id,
user_id,
@ -100,7 +103,7 @@ func (session Session) RootNode() (node Node, err error) {// {{{
user_id = $1 AND
parent_id IS NULL
`,
session.UserID,
userID,
)
if err != nil {
return
@ -108,7 +111,7 @@ func (session Session) RootNode() (node Node, err error) {// {{{
defer rows.Close()
node.Name = "Start"
node.UserID = session.UserID
node.UserID = userID
node.Complete = true
node.Children = []Node{}
node.Crumbs = []Node{}
@ -129,13 +132,13 @@ func (session Session) RootNode() (node Node, err error) {// {{{
return
}// }}}
func (session Session) Node(nodeID int) (node Node, err error) {// {{{
func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
if nodeID == 0 {
return session.RootNode()
return RootNode(userID)
}
var rows *sqlx.Rows
rows, err = db.Queryx(`
rows, err = service.Db.Conn.Queryx(`
WITH RECURSIVE recurse AS (
SELECT
id,
@ -170,7 +173,7 @@ func (session Session) Node(nodeID int) (node Node, err error) {// {{{
SELECT * FROM recurse ORDER BY level ASC
`,
session.UserID,
userID,
nodeID,
)
if err != nil {
@ -217,14 +220,14 @@ func (session Session) Node(nodeID int) (node Node, err error) {// {{{
}
}
node.Crumbs, err = session.NodeCrumbs(node.ID)
node.Files, err = session.Files(node.ID, 0)
node.Crumbs, err = NodeCrumbs(node.ID)
node.Files, err = Files(userID, node.ID, 0)
return
}// }}}
func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = db.Queryx(`
rows, err = service.Db.Conn.Queryx(`
WITH RECURSIVE nodes AS (
SELECT
id,
@ -260,10 +263,10 @@ func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
}
return
}// }}}
func (session Session) CreateNode(parentID int, name string) (node Node, err error) {// {{{
func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = db.Queryx(`
rows, err = service.Db.Conn.Queryx(`
INSERT INTO node(user_id, parent_id, name)
VALUES($1, NULLIF($2, 0)::integer, $3)
RETURNING
@ -273,7 +276,7 @@ func (session Session) CreateNode(parentID int, name string) (node Node, err err
name,
content
`,
session.UserID,
userID,
parentID,
name,
)
@ -292,12 +295,12 @@ func (session Session) CreateNode(parentID int, name string) (node Node, err err
node.Complete = true
}
node.Crumbs, err = session.NodeCrumbs(node.ID)
node.Crumbs, err = NodeCrumbs(node.ID)
return
}// }}}
func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (err error) {// {{{
func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) {// {{{
if cryptoKeyID > 0 {
_, err = db.Exec(`
_, err = service.Db.Conn.Exec(`
UPDATE node
SET
content = '',
@ -313,10 +316,10 @@ func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (
content,
cryptoKeyID,
nodeID,
session.UserID,
userID,
)
} else {
_, err = db.Exec(`
_, err = service.Db.Conn.Exec(`
UPDATE node
SET
content = $1,
@ -332,24 +335,24 @@ func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (
content,
cryptoKeyID,
nodeID,
session.UserID,
userID,
)
}
return
}// }}}
func (session Session) RenameNode(nodeID int, name string) (err error) {// {{{
_, err = db.Exec(`
func RenameNode(userID, nodeID int, name string) (err error) {// {{{
_, err = service.Db.Conn.Exec(`
UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3
`,
name,
session.UserID,
userID,
nodeID,
)
return
}// }}}
func (session Session) DeleteNode(nodeID int) (err error) {// {{{
_, err = db.Exec(`
func DeleteNode(userID, nodeID int) (err error) {// {{{
_, err = service.Db.Conn.Exec(`
WITH RECURSIVE nodetree AS (
SELECT
id, parent_id
@ -368,15 +371,15 @@ func (session Session) DeleteNode(nodeID int) (err error) {// {{{
DELETE FROM node WHERE id IN (
SELECT id FROM nodetree
)`,
session.UserID,
userID,
nodeID,
)
return
}// }}}
func (session Session) SearchNodes(search string) (nodes []Node, err error) {// {{{
func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
nodes = []Node{}
var rows *sqlx.Rows
rows, err = db.Queryx(`
rows, err = service.Db.Conn.Queryx(`
SELECT
id,
user_id,
@ -385,14 +388,15 @@ func (session Session) SearchNodes(search string) (nodes []Node, err error) {//
updated
FROM node
WHERE
user_id = $1 AND
crypto_key_id IS NULL AND
(
content ~* $1 OR
name ~* $1
content ~* $2 OR
name ~* $2
)
ORDER BY
updated DESC
`, search)
`, userID, search)
if err != nil {
return
}

View File

@ -1,119 +0,0 @@
package main
import (
// Standard
"database/sql"
"errors"
"fmt"
"net/http"
"time"
)
type Session struct {
UUID string
UserID int
Created time.Time
}
func CreateSession() (session Session, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
INSERT INTO public.session(uuid)
VALUES(gen_random_uuid())
RETURNING uuid, created`,
); err != nil {
return
}
defer rows.Close()
if rows.Next() {
rows.Scan(&session.UUID, &session.Created)
}
return
}// }}}
func sessionUUID(r *http.Request) (string, error) {// {{{
headers := r.Header["X-Session-Id"]
if len(headers) > 0 {
return headers[0], nil
}
return "", errors.New("Invalid session")
}// }}}
func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{
var uuid string
if uuid, err = sessionUUID(r); err != nil {
return
}
session.UUID = uuid
if found, err = session.Retrieve(); err != nil {
return
}
if notFoundIsError && !found {
err = errors.New("Invalid session")
return
}
return
}// }}}
func (session *Session) Retrieve() (found bool, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
SELECT
uuid, user_id, created
FROM public.session
WHERE
uuid = $1 AND
created + $2::interval >= NOW()
`,
session.UUID,
fmt.Sprintf("%d days", config.Session.DaysValid),
); err != nil {
return
}
defer rows.Close()
found = false
if rows.Next() {
found = true
rows.Scan(&session.UUID, &session.UserID, &session.Created)
}
return
}// }}}
func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
SELECT id
FROM public.user
WHERE
username=$1 AND
password=password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
`,
username,
password,
); err != nil {
return
}
defer rows.Close()
if rows.Next() {
rows.Scan(&session.UserID)
authenticated = session.UserID > 0
}
if authenticated {
_, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID)
if err != nil {
return
}
}
return
}// }}}
// vim: foldmethod=marker

View File

@ -1,34 +0,0 @@
/* Required for the gen_random_bytes function */
CREATE EXTENSION pgcrypto;
CREATE FUNCTION password_hash(salt_hex char(32), pass bytea)
RETURNS char(96)
LANGUAGE plpgsql
AS
$$
BEGIN
RETURN (
SELECT
salt_hex ||
encode(
sha256(
decode(salt_hex, 'hex') || /* salt in binary */
pass /* password */
),
'hex'
)
);
END;
$$;
/* Password has to be able to accommodate 96 characters instead of previous 64.
* It can't be char(96), because then the password would be padded to 96 characters. */
ALTER TABLE public."user" ALTER COLUMN "password" TYPE varchar(96) USING "password"::varchar;
/* Update all users with salted and hashed passwords */
UPDATE public.user
SET password = password_hash( encode(gen_random_bytes(16),'hex'), password::bytea);
/* After the password hashing, all passwords are now hex encoded 32 characters salt and 64 characters hash,
* and the varchar type is not longer necessary. */
ALTER TABLE public."user" ALTER COLUMN "password" TYPE char(96) USING "password"::varchar;

View File

@ -13,8 +13,8 @@ class App extends Component {
this.websocket = null
this.websocket_int_ping = null
this.websocket_int_reconnect = null
this.wsConnect()
this.wsLoop()
//this.wsConnect() // XXX
//this.wsLoop() // XXX
this.session = new Session(this)
this.session.initialize()
@ -114,7 +114,7 @@ class App extends Component {
}
if(this.session.UUID !== '')
headers['X-Session-Id'] = this.session.UUID
headers['X-Session-ID'] = this.session.UUID
fetch(url, {
method: 'POST',
@ -201,9 +201,24 @@ class Tree extends Component {
this.selectedTreeNode = null
this.props.app.tree = this
this.retrieve()
}//}}}
render({ app }) {//{{{
let renderedTreeTrunk = this.treeTrunk.map(node=>{
this.treeNodeComponents[node.ID] = createRef()
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
})
return html`<div id="tree">${renderedTreeTrunk}</div>`
}//}}}
retrieve(callback = null) {//{{{
this.props.app.request('/node/tree', { StartNodeID: 0 })
.then(res=>{
this.treeNodes = {}
this.treeNodeComponents = {}
this.treeTrunk = []
this.selectedTreeNode = null
// A tree of nodes is built. This requires the list of nodes
// returned from the server to be sorted in such a way that
// a parent node always appears before a child node.
@ -235,17 +250,12 @@ class Tree extends Component {
this.crumbsUpdateNodes()
this.forceUpdate()
if(callback)
callback()
})
.catch(this.responseError)
}//}}}
render({ app }) {//{{{
let renderedTreeTrunk = this.treeTrunk.map(node=>{
this.treeNodeComponents[node.ID] = createRef()
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
})
return html`<div id="tree">${renderedTreeTrunk}</div>`
}//}}}
setSelected(node) {//{{{
if(this.selectedTreeNode)
this.selectedTreeNode.selected.value = false

View File

@ -215,7 +215,14 @@ export class NodeUI extends Component {
let name = prompt("Name")
if(!name)
return
this.node.value.create(name, nodeID=>this.goToNode(nodeID))
this.node.value.create(name, nodeID=>{
console.log('before', this.props.app.startNode)
this.props.app.startNode = new Node(this.props.app, nodeID)
console.log('after', this.props.app.startNode)
this.props.app.tree.retrieve(()=>{
this.goToNode(nodeID)
})
})
}//}}}
saveNode() {//{{{
let nodeContent = this.nodeContent.current

View File

@ -1,16 +1,15 @@
export class Session {
constructor(app) {
constructor(app) {//{{{
this.app = app
this.UUID = ''
this.initialized = false
this.UserID = 0
}
}//}}}
initialize() {//{{{
// Retrieving the stored session UUID, if any.
// If one found, validate with server.
// If the browser doesn't know anything about a session,
// a call to /session/create is necessary to retrieve a session UUID.
let uuid = window.localStorage.getItem("session.UUID")
@ -25,9 +24,9 @@ export class Session {
// A call to /session/retrieve with a session UUID validates that the
// session is still valid and returns all session information.
this.UUID = uuid
this.app.request('/session/retrieve', {})
this.app.request('/_session/retrieve', {})
.then(res => {
if(res.Valid) {
if (res.Error === undefined) {
// Session exists on server.
// Not necessarily authenticated.
this.UserID = res.Session.UserID // could be 0
@ -41,10 +40,10 @@ export class Session {
.catch(this.app.responseError)
}//}}}
create() {//{{{
this.app.request('/session/create', {})
this.app.request('/_session/new', {})
.then(res => {
this.UUID = res.Session.UUID
window.localStorage.setItem('session.UUID', this.UUID)
window.localStorage.setItem('session.UUID', this.Session.UUID)
this.initialized = true
this.app.forceUpdate()
})
@ -53,13 +52,13 @@ export class Session {
authenticate(username, password) {//{{{
this.app.login.current.authentication_failed.value = false
this.app.request('/session/authenticate', {
this.app.request('/_session/authenticate', {
username,
password,
})
.then(res => {
if (res.Authenticated) {
this.UserID = res.Session.UserID
this.UserID = res.UserID
this.app.forceUpdate()
} else {
this.app.login.current.authentication_failed.value = true

47
user.go
View File

@ -1,11 +1,6 @@
package main
import (
// Standard
"database/sql"
"fmt"
)
/*
func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err error) {
var result sql.Result
var rowsAffected int64
@ -14,10 +9,10 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
UPDATE public.user
SET
password = password_hash(
/* salt in hex */
/ salt in hex /
ENCODE(gen_random_bytes(16), 'hex'),
/* password */
/ password /
$1::bytea
)
WHERE
@ -36,38 +31,4 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
return rowsAffected > 0, nil
}
func createUser() (err error) {
var username, password string
fmt.Printf("Username: ")
fmt.Scanln(&username)
fmt.Printf("Password: ")
fmt.Scanln(&password)
err = CreateDbUser(username, password)
return
}
func CreateDbUser(username string, password string) (err error) {
var result sql.Result
result, err = db.Exec(`
INSERT INTO public.user(username, password)
VALUES(
$1,
password_hash(
/* salt in hex */
ENCODE(gen_random_bytes(16), 'hex'),
/* password */
$2::bytea
)
)
`,
username,
password,
)
_, err = result.RowsAffected()
return
}
*/