Compare commits

..

No commits in common. "afa113b8b7066c8db589434fe5fe53f6d34cac3c" and "abbd320b93b849b4bbaa212a17e4c3c8b27f18ef" have entirely different histories.

23 changed files with 715 additions and 205 deletions

150
db.go Normal file
View File

@ -0,0 +1,150 @@
package main
import (
// External
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
// Standard
"errors"
"fmt"
"io/fs"
"regexp"
"strconv"
)
var (
dbConn string
db *sqlx.DB
)
func dbInit() (err error) { // {{{
dbConn = fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
config.Database.Host,
config.Database.Port,
config.Database.Username,
config.Database.Password,
config.Database.Name,
)
logger.Info("db", "op", "connect", "host", config.Database.Host, "port", config.Database.Port)
if db, err = sqlx.Connect("postgres", dbConn); err != nil {
return
}
if err = dbVerifyInternals(); err != nil {
return
}
err = dbUpdate()
return
} // }}}
func dbVerifyInternals() (err error) { // {{{
var rows *sqlx.Rows
if rows, err = db.Queryx(
`SELECT EXISTS (
SELECT FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = '_internal'
AND c.relname = 'db'
)`,
); err != nil {
return
}
defer rows.Close()
var exists bool
rows.Next()
if err = rows.Scan(&exists); err != nil {
return
}
if !exists {
logger.Info("db", "op", "create_db", "db", "_internal.db")
if _, err = db.Exec(`
CREATE SCHEMA "_internal";
CREATE TABLE "_internal".db (
"key" varchar NOT NULL,
value varchar NULL,
CONSTRAINT db_pk PRIMARY KEY (key)
);
INSERT INTO _internal.db("key", "value")
VALUES('schema', '0');
`,
); err != nil {
return
}
}
return
} // }}}
func dbUpdate() (err error) { // {{{
/* Current schema revision is read from database.
* Used to iterate through the embedded SQL updates
* up to the db schema version currently compiled
* program is made for. */
var rows *sqlx.Rows
var schemaStr string
var schema int
rows, err = db.Queryx(`SELECT value FROM _internal.db WHERE "key"='schema'`)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return errors.New("Table _interval.db missing schema row")
}
if err = rows.Scan(&schemaStr); err != nil {
return
}
// Run updates
schema, err = strconv.Atoi(schemaStr)
if err != nil {
return err
}
sqlSchemaVersion := sqlSchema()
for i := (schema + 1); i <= sqlSchemaVersion; i++ {
logger.Info("db", "op", "upgrade_schema", "schema", i)
sql, _ := embedded.ReadFile(
fmt.Sprintf("sql/%04d.sql", i),
)
_, err = db.Exec(string(sql))
if err != nil {
return
}
_, err = db.Exec(`UPDATE _internal.db SET "value"=$1 WHERE "key"='schema'`, i)
if err != nil {
return
}
logger.Info("db", "op", "upgrade_schema", "schema", i, "result", "ok")
}
return
} // }}}
func sqlSchema() (max int) { // {{{
var num int
files, _ := fs.ReadDir(embedded, "sql")
sqlFilename := regexp.MustCompile(`^([0-9]+)\.sql$`)
for _, file := range files {
fname := sqlFilename.FindStringSubmatch(file.Name())
if len(fname) != 2 {
continue
}
num, _ = strconv.Atoi(fname[1])
}
if num > max {
max = num
}
return
} // }}}
// vim: foldmethod=marker

14
file.go
View File

@ -5,6 +5,7 @@ import (
"github.com/jmoiron/sqlx"
// Standard
"fmt"
"time"
)
@ -19,11 +20,11 @@ type File struct {
Uploaded time.Time
}
func AddFile(userID int, file *File) (err error) { // {{{
file.UserID = userID
func (session Session) AddFile(file *File) (err error) { // {{{
file.UserID = session.UserID
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`
rows, err = db.Queryx(`
INSERT INTO file(user_id, node_id, filename, size, mime, md5)
VALUES($1, $2, $3, $4, $5, $6)
RETURNING id
@ -42,11 +43,12 @@ func AddFile(userID int, file *File) (err error) { // {{{
rows.Next()
err = rows.Scan(&file.ID)
fmt.Printf("%#v\n", file)
return
} // }}}
func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
func (session Session) Files(nodeID, fileID int) (files []File, err error) { // {{{
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(
rows, err = db.Queryx(
`SELECT *
FROM file
WHERE
@ -56,7 +58,7 @@ func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
WHEN 0 THEN true
ELSE id = $3
END`,
userID,
session.UserID,
nodeID,
fileID,
)

22
key.go
View File

@ -3,6 +3,8 @@ package main
import (
// External
"github.com/jmoiron/sqlx"
// Standard
)
type Key struct {
@ -12,9 +14,9 @@ type Key struct {
Key string
}
func Keys(userID int) (keys []Key, err error) { // {{{
func (session Session) Keys() (keys []Key, err error) {// {{{
var rows *sqlx.Rows
if rows, err = service.Db.Conn.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, userID); err != nil {
if rows, err = db.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, session.UserID); err != nil {
return
}
defer rows.Close()
@ -29,14 +31,14 @@ func Keys(userID int) (keys []Key, err error) { // {{{
}
return
} // }}}
func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) { // {{{
}// }}}
func (session Session) KeyCreate(description, keyEncoded string) (key Key, err error) {// {{{
var row *sqlx.Rows
if row, err = service.Db.Conn.Queryx(
if row, err = db.Queryx(
`INSERT INTO crypto_key(user_id, description, key)
VALUES($1, $2, $3)
RETURNING *`,
userID,
session.UserID,
description,
keyEncoded,
); err != nil {
@ -52,10 +54,10 @@ func KeyCreate(userID int, description, keyEncoded string) (key Key, err error)
}
return
} // }}}
func KeyCounter() (counter int64, err error) { // {{{
}// }}}
func (session Session) KeyCounter() (counter int64, err error) {// {{{
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
rows, err = db.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
if err != nil {
return
}
@ -64,6 +66,6 @@ func KeyCounter() (counter int64, err error) { // {{{
rows.Next()
err = rows.Scan(&counter)
return
} // }}}
}// }}}
// vim: foldmethod=marker

346
main.go
View File

@ -1,23 +1,20 @@
package main
import (
// External
"git.gibonuddevalla.se/go/webservice"
// Internal
"git.gibonuddevalla.se/go/webservice/session"
// Standard
"crypto/md5"
"embed"
"encoding/hex"
"flag"
"fmt"
"html/template"
"io"
"log/slog"
"net/http"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
)
@ -28,10 +25,8 @@ var (
flagPort int
flagVersion bool
flagCreateUser bool
flagCheckLocal bool
flagConfig string
service *webservice.Service
connectionManager ConnectionManager
static http.Handler
config Config
@ -39,24 +34,11 @@ var (
VERSION string
//go:embed version sql/*
embeddedSQL embed.FS
//go:embed static
staticFS embed.FS
embedded embed.FS
)
func sqlProvider(dbname string, version int) (sql []byte, found bool) {
var err error
sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version))
if err != nil {
return
}
found = true
return
}
func init() { // {{{
version, _ := embeddedSQL.ReadFile("version")
version, _ := embedded.ReadFile("version")
VERSION = strings.TrimSpace(string(version))
opt := slog.HandlerOptions{}
@ -66,7 +48,6 @@ func init() { // {{{
flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on")
flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists")
flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit")
flag.BoolVar(&flagCheckLocal, "checklocal", false, "Check for local static file before embedded")
flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file")
flag.Parse()
} // }}}
@ -77,47 +58,56 @@ func main() { // {{{
fmt.Printf("%s\n", VERSION)
os.Exit(0)
}
logger.Info("application", "version", VERSION)
service, err = webservice.New(flagConfig, VERSION)
config, err = ConfigRead(flagConfig)
if err != nil {
logger.Error("application", "error", err)
logger.Error("config", "error", err)
os.Exit(1)
}
service.SetDatabase(sqlProvider)
service.SetStaticDirectory("static", true)
service.SetStaticFS(staticFS, "static")
service.Register("/node/upload", true, true, nodeUpload)
service.Register("/node/tree", true, true, nodeTree)
service.Register("/node/retrieve", true, true, nodeRetrieve)
service.Register("/node/create", true, true, nodeCreate)
service.Register("/node/update", true, true, nodeUpdate)
service.Register("/node/rename", true, true, nodeRename)
service.Register("/node/delete", true, true, nodeDelete)
service.Register("/node/download", true, true, nodeDownload)
service.Register("/node/search", true, true, nodeSearch)
service.Register("/key/retrieve", true, true, keyRetrieve)
service.Register("/key/create", true, true, keyCreate)
service.Register("/key/counter", true, true, keyCounter)
service.Register("/ws", false, false, service.WebsocketHandler)
service.Register("/", false, false, service.StaticHandler)
if err = dbInit(); err != nil {
logger.Error("db", "error", err)
os.Exit(1)
}
if flagCreateUser {
service.CreateUserPrompt()
err = createUser()
if err != nil {
logger.Error("db", "error", err)
os.Exit(1)
}
os.Exit(0)
}
err = service.Start()
if err != nil {
logger.Error("webserver", "error", err)
os.Exit(1)
}
connectionManager = NewConnectionManager()
go connectionManager.BroadcastLoop()
static = http.FileServer(http.Dir(config.Application.Directories.Static))
http.HandleFunc("/css_updated", cssUpdateHandler)
http.HandleFunc("/session/create", sessionCreate)
http.HandleFunc("/session/retrieve", sessionRetrieve)
http.HandleFunc("/session/authenticate", sessionAuthenticate)
http.HandleFunc("/user/password", userPassword)
http.HandleFunc("/node/tree", nodeTree)
http.HandleFunc("/node/retrieve", nodeRetrieve)
http.HandleFunc("/node/create", nodeCreate)
http.HandleFunc("/node/update", nodeUpdate)
http.HandleFunc("/node/rename", nodeRename)
http.HandleFunc("/node/delete", nodeDelete)
http.HandleFunc("/node/upload", nodeUpload)
http.HandleFunc("/node/download", nodeDownload)
http.HandleFunc("/node/search", nodeSearch)
http.HandleFunc("/key/retrieve", keyRetrieve)
http.HandleFunc("/key/create", keyCreate)
http.HandleFunc("/key/counter", keyCounter)
http.HandleFunc("/ws", websocketHandler)
http.HandleFunc("/", staticHandler)
listen := fmt.Sprintf("%s:%d", LISTEN_HOST, flagPort)
logger.Info("webserver", "listen", listen, "domains", config.Websocket.Domains)
http.ListenAndServe(listen, nil)
} // }}}
func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{
@ -137,8 +127,108 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{
return
}
} // }}}
func staticHandler(w http.ResponseWriter, r *http.Request) { // {{{
data := struct {
VERSION string
}{
VERSION: VERSION,
}
// URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version.
// To get rid of problems with cached content in browser on a new version release,
// while also not disabling cache altogether.
logger.Debug("webserver", "request", r.URL.Path)
if r.URL.Path == "/favicon.ico" {
static.ServeHTTP(w, r)
return
}
rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
static.ServeHTTP(w, r)
return
}
// Everything else is run through the template system.
// For now to get VERSION into files to fix caching.
logger.Debug("webserver", "template", r.URL.Path)
tmpl, err := newTemplate(r.URL.Path)
if err != nil {
if os.IsNotExist(err) {
w.WriteHeader(404)
}
w.Write([]byte(err.Error()))
return
}
if err = tmpl.Execute(w, data); err != nil {
w.Write([]byte(err.Error()))
}
} // }}}
func sessionCreate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/create")
session, err := CreateSession()
if err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Session": session,
})
} // }}}
func sessionRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/retrieve")
var err error
var found bool
var session Session
if session, found, err = ValidateSession(r, false); err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Valid": found,
"Session": session,
})
} // }}}
func sessionAuthenticate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/session/authenticate")
var err error
var session Session
var authenticated bool
// Validate session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Username string
Password string
}{}
if err = parseRequest(r, &req); err != nil {
responseError(w, err)
return
}
if authenticated, err = session.Authenticate(req.Username, req.Password); err != nil {
responseError(w, err)
return
}
responseData(w, map[string]interface{}{
"OK": true,
"Authenticated": authenticated,
"Session": session,
})
} // }}}
/*
func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
var err error
var ok bool
@ -169,11 +259,15 @@ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
"CurrentPasswordOK": ok,
})
} // }}}
*/
func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
logger.Info("webserver", "request", "/node/tree")
func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct{ StartNodeID int }{}
if err = parseRequest(r, &req); err != nil {
@ -181,7 +275,7 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
return
}
nodes, err := NodeTree(sess.UserID, req.StartNodeID)
nodes, err := session.NodeTree(req.StartNodeID)
if err != nil {
responseError(w, err)
return
@ -192,9 +286,15 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
"Nodes": nodes,
})
} // }}}
func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/retrieve")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct{ ID int }{}
if err = parseRequest(r, &req); err != nil {
@ -202,7 +302,7 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { //
return
}
node, err := RetrieveNode(sess.UserID, req.ID)
node, err := session.Node(req.ID)
if err != nil {
responseError(w, err)
return
@ -213,9 +313,15 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { //
"Node": node,
})
} // }}}
func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/create")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Name string
@ -226,7 +332,7 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
return
}
node, err := CreateNode(sess.UserID, req.ParentID, req.Name)
node, err := session.CreateNode(req.ParentID, req.Name)
if err != nil {
responseError(w, err)
return
@ -237,9 +343,15 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
"Node": node,
})
} // }}}
func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/update")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
@ -251,7 +363,7 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
return
}
err = UpdateNode(sess.UserID, req.NodeID, req.Content, req.CryptoKeyID)
err = session.UpdateNode(req.NodeID, req.Content, req.CryptoKeyID)
if err != nil {
responseError(w, err)
return
@ -261,10 +373,17 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
"OK": true,
})
} // }}}
func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
var err error
func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/rename")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
Name string
@ -274,7 +393,7 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
return
}
err = RenameNode(sess.UserID, req.NodeID, req.Name)
err = session.RenameNode(req.NodeID, req.Name)
if err != nil {
responseError(w, err)
return
@ -284,10 +403,17 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
"OK": true,
})
} // }}}
func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
var err error
func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/delete")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
}{}
@ -296,7 +422,7 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
return
}
err = DeleteNode(sess.UserID, req.NodeID)
err = session.DeleteNode(req.NodeID)
if err != nil {
responseError(w, err)
return
@ -306,9 +432,15 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
"OK": true,
})
} // }}}
func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/upload")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
// Parse our multipart form, 10 << 20 specifies a maximum
// upload of 10 MB files.
@ -347,7 +479,7 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
MIME: handler.Header.Get("Content-Type"),
MD5: md5sum,
}
if err = AddFile(sess.UserID, &nodeFile); err != nil {
if err = session.AddFile(&nodeFile); err != nil {
responseError(w, err)
return
}
@ -384,11 +516,17 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
"File": nodeFile,
})
} // }}}
func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/download")
var err error
var session Session
var files []File
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
FileID int
@ -398,7 +536,7 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { //
return
}
files, err = Files(sess.UserID, req.NodeID, req.FileID)
files, err = session.Files(req.NodeID, req.FileID)
if err != nil {
responseError(w, err)
return
@ -445,11 +583,17 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { //
}
} // }}}
func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/files")
var err error
var session Session
var files []File
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
NodeID int
}{}
@ -458,7 +602,7 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
return
}
files, err = Files(sess.UserID, req.NodeID, 0)
files, err = session.Files(req.NodeID, 0)
if err != nil {
responseError(w, err)
return
@ -469,11 +613,17 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
"Files": files,
})
} // }}}
func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/node/search")
var err error
var session Session
var nodes []Node
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Search string
}{}
@ -482,7 +632,7 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
return
}
nodes, err = SearchNodes(sess.UserID, req.Search)
nodes, err = session.SearchNodes(req.Search)
if err != nil {
responseError(w, err)
return
@ -494,11 +644,17 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
})
} // }}}
func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/key/retrieve")
var err error
var session Session
keys, err := Keys(sess.UserID)
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
keys, err := session.Keys()
if err != nil {
responseError(w, err)
return
@ -509,9 +665,15 @@ func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {
"Keys": keys,
})
} // }}}
func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/key/create")
var err error
var session Session
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
req := struct {
Description string
@ -522,7 +684,7 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
return
}
key, err := KeyCreate(sess.UserID, req.Description, req.Key)
key, err := session.KeyCreate(req.Description, req.Key)
if err != nil {
responseError(w, err)
return
@ -533,11 +695,17 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
"Key": key,
})
} // }}}
func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{
logger.Info("webserver", "request", "/key/counter")
var err error
var session Session
counter, err := KeyCounter()
if session, _, err = ValidateSession(r, true); err != nil {
responseError(w, err)
return
}
counter, err := session.KeyCounter()
if err != nil {
responseError(w, err)
return
@ -550,4 +718,20 @@ func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
})
} // }}}
func newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{
// Append index.html if needed for further reading of the file
p := requestPath
if p[len(p)-1] == '/' {
p += "index.html"
}
p = config.Application.Directories.Static + p
base := path.Base(p)
if tmpl, err = template.New(base).ParseFiles(p); err != nil {
return
}
return
} // }}}
// vim: foldmethod=marker

74
node.go
View File

@ -25,9 +25,9 @@ type Node struct {
ContentEncrypted string `db:"content_encrypted" json:"-"`
}
func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`
rows, err = db.Queryx(`
WITH RECURSIVE nodetree AS (
SELECT
*,
@ -62,7 +62,7 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
ORDER BY
path ASC
`,
userID,
session.UserID,
startNodeID,
)
if err != nil {
@ -79,9 +79,6 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
for rows.Next() {
node := Node{}
node.Complete = false
node.Crumbs = []Node{}
node.Children = []Node{}
node.Files = []File{}
if err = rows.StructScan(&node); err != nil {
return
}
@ -90,9 +87,9 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
return
}// }}}
func RootNode(userID int) (node Node, err error) {// {{{
func (session Session) RootNode() (node Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`
rows, err = db.Queryx(`
SELECT
id,
user_id,
@ -103,7 +100,7 @@ func RootNode(userID int) (node Node, err error) {// {{{
user_id = $1 AND
parent_id IS NULL
`,
userID,
session.UserID,
)
if err != nil {
return
@ -111,7 +108,7 @@ func RootNode(userID int) (node Node, err error) {// {{{
defer rows.Close()
node.Name = "Start"
node.UserID = userID
node.UserID = session.UserID
node.Complete = true
node.Children = []Node{}
node.Crumbs = []Node{}
@ -132,13 +129,13 @@ func RootNode(userID int) (node Node, err error) {// {{{
return
}// }}}
func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
func (session Session) Node(nodeID int) (node Node, err error) {// {{{
if nodeID == 0 {
return RootNode(userID)
return session.RootNode()
}
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`
rows, err = db.Queryx(`
WITH RECURSIVE recurse AS (
SELECT
id,
@ -173,7 +170,7 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
SELECT * FROM recurse ORDER BY level ASC
`,
userID,
session.UserID,
nodeID,
)
if err != nil {
@ -220,14 +217,14 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
}
}
node.Crumbs, err = NodeCrumbs(node.ID)
node.Files, err = Files(userID, node.ID, 0)
node.Crumbs, err = session.NodeCrumbs(node.ID)
node.Files, err = session.Files(node.ID, 0)
return
}// }}}
func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`
rows, err = db.Queryx(`
WITH RECURSIVE nodes AS (
SELECT
id,
@ -263,10 +260,10 @@ func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
}
return
}// }}}
func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{{
func (session Session) CreateNode(parentID int, name string) (node Node, err error) {// {{{
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`
rows, err = db.Queryx(`
INSERT INTO node(user_id, parent_id, name)
VALUES($1, NULLIF($2, 0)::integer, $3)
RETURNING
@ -276,7 +273,7 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{
name,
content
`,
userID,
session.UserID,
parentID,
name,
)
@ -295,12 +292,12 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{
node.Complete = true
}
node.Crumbs, err = NodeCrumbs(node.ID)
node.Crumbs, err = session.NodeCrumbs(node.ID)
return
}// }}}
func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) {// {{{
func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (err error) {// {{{
if cryptoKeyID > 0 {
_, err = service.Db.Conn.Exec(`
_, err = db.Exec(`
UPDATE node
SET
content = '',
@ -316,10 +313,10 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error)
content,
cryptoKeyID,
nodeID,
userID,
session.UserID,
)
} else {
_, err = service.Db.Conn.Exec(`
_, err = db.Exec(`
UPDATE node
SET
content = $1,
@ -335,24 +332,24 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error)
content,
cryptoKeyID,
nodeID,
userID,
session.UserID,
)
}
return
}// }}}
func RenameNode(userID, nodeID int, name string) (err error) {// {{{
_, err = service.Db.Conn.Exec(`
func (session Session) RenameNode(nodeID int, name string) (err error) {// {{{
_, err = db.Exec(`
UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3
`,
name,
userID,
session.UserID,
nodeID,
)
return
}// }}}
func DeleteNode(userID, nodeID int) (err error) {// {{{
_, err = service.Db.Conn.Exec(`
func (session Session) DeleteNode(nodeID int) (err error) {// {{{
_, err = db.Exec(`
WITH RECURSIVE nodetree AS (
SELECT
id, parent_id
@ -371,15 +368,15 @@ func DeleteNode(userID, nodeID int) (err error) {// {{{
DELETE FROM node WHERE id IN (
SELECT id FROM nodetree
)`,
userID,
session.UserID,
nodeID,
)
return
}// }}}
func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
func (session Session) SearchNodes(search string) (nodes []Node, err error) {// {{{
nodes = []Node{}
var rows *sqlx.Rows
rows, err = service.Db.Conn.Queryx(`
rows, err = db.Queryx(`
SELECT
id,
user_id,
@ -388,15 +385,14 @@ func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
updated
FROM node
WHERE
user_id = $1 AND
crypto_key_id IS NULL AND
(
content ~* $2 OR
name ~* $2
content ~* $1 OR
name ~* $1
)
ORDER BY
updated DESC
`, userID, search)
`, search)
if err != nil {
return
}

119
session.go Normal file
View File

@ -0,0 +1,119 @@
package main
import (
// Standard
"database/sql"
"errors"
"fmt"
"net/http"
"time"
)
type Session struct {
UUID string
UserID int
Created time.Time
}
func CreateSession() (session Session, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
INSERT INTO public.session(uuid)
VALUES(gen_random_uuid())
RETURNING uuid, created`,
); err != nil {
return
}
defer rows.Close()
if rows.Next() {
rows.Scan(&session.UUID, &session.Created)
}
return
}// }}}
func sessionUUID(r *http.Request) (string, error) {// {{{
headers := r.Header["X-Session-Id"]
if len(headers) > 0 {
return headers[0], nil
}
return "", errors.New("Invalid session")
}// }}}
func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{
var uuid string
if uuid, err = sessionUUID(r); err != nil {
return
}
session.UUID = uuid
if found, err = session.Retrieve(); err != nil {
return
}
if notFoundIsError && !found {
err = errors.New("Invalid session")
return
}
return
}// }}}
func (session *Session) Retrieve() (found bool, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
SELECT
uuid, user_id, created
FROM public.session
WHERE
uuid = $1 AND
created + $2::interval >= NOW()
`,
session.UUID,
fmt.Sprintf("%d days", config.Session.DaysValid),
); err != nil {
return
}
defer rows.Close()
found = false
if rows.Next() {
found = true
rows.Scan(&session.UUID, &session.UserID, &session.Created)
}
return
}// }}}
func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{
var rows *sql.Rows
if rows, err = db.Query(`
SELECT id
FROM public.user
WHERE
username=$1 AND
password=password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
`,
username,
password,
); err != nil {
return
}
defer rows.Close()
if rows.Next() {
rows.Scan(&session.UserID)
authenticated = session.UserID > 0
}
if authenticated {
_, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID)
if err != nil {
return
}
}
return
}// }}}
// vim: foldmethod=marker

34
sql/0013.sql Normal file
View File

@ -0,0 +1,34 @@
/* Required for the gen_random_bytes function */
CREATE EXTENSION pgcrypto;
CREATE FUNCTION password_hash(salt_hex char(32), pass bytea)
RETURNS char(96)
LANGUAGE plpgsql
AS
$$
BEGIN
RETURN (
SELECT
salt_hex ||
encode(
sha256(
decode(salt_hex, 'hex') || /* salt in binary */
pass /* password */
),
'hex'
)
);
END;
$$;
/* Password has to be able to accommodate 96 characters instead of previous 64.
* It can't be char(96), because then the password would be padded to 96 characters. */
ALTER TABLE public."user" ALTER COLUMN "password" TYPE varchar(96) USING "password"::varchar;
/* Update all users with salted and hashed passwords */
UPDATE public.user
SET password = password_hash( encode(gen_random_bytes(16),'hex'), password::bytea);
/* After the password hashing, all passwords are now hex encoded 32 characters salt and 64 characters hash,
* and the varchar type is not longer necessary. */
ALTER TABLE public."user" ALTER COLUMN "password" TYPE char(96) USING "password"::varchar;

View File

@ -13,8 +13,8 @@ class App extends Component {
this.websocket = null
this.websocket_int_ping = null
this.websocket_int_reconnect = null
//this.wsConnect() // XXX
//this.wsLoop() // XXX
this.wsConnect()
this.wsLoop()
this.session = new Session(this)
this.session.initialize()
@ -114,7 +114,7 @@ class App extends Component {
}
if(this.session.UUID !== '')
headers['X-Session-ID'] = this.session.UUID
headers['X-Session-Id'] = this.session.UUID
fetch(url, {
method: 'POST',
@ -201,24 +201,9 @@ class Tree extends Component {
this.selectedTreeNode = null
this.props.app.tree = this
this.retrieve()
}//}}}
render({ app }) {//{{{
let renderedTreeTrunk = this.treeTrunk.map(node=>{
this.treeNodeComponents[node.ID] = createRef()
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
})
return html`<div id="tree">${renderedTreeTrunk}</div>`
}//}}}
retrieve(callback = null) {//{{{
this.props.app.request('/node/tree', { StartNodeID: 0 })
.then(res=>{
this.treeNodes = {}
this.treeNodeComponents = {}
this.treeTrunk = []
this.selectedTreeNode = null
// A tree of nodes is built. This requires the list of nodes
// returned from the server to be sorted in such a way that
// a parent node always appears before a child node.
@ -250,12 +235,17 @@ class Tree extends Component {
this.crumbsUpdateNodes()
this.forceUpdate()
if(callback)
callback()
})
.catch(this.responseError)
}//}}}
render({ app }) {//{{{
let renderedTreeTrunk = this.treeTrunk.map(node=>{
this.treeNodeComponents[node.ID] = createRef()
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
})
return html`<div id="tree">${renderedTreeTrunk}</div>`
}//}}}
setSelected(node) {//{{{
if(this.selectedTreeNode)
this.selectedTreeNode.selected.value = false

View File

@ -215,14 +215,7 @@ export class NodeUI extends Component {
let name = prompt("Name")
if(!name)
return
this.node.value.create(name, nodeID=>{
console.log('before', this.props.app.startNode)
this.props.app.startNode = new Node(this.props.app, nodeID)
console.log('after', this.props.app.startNode)
this.props.app.tree.retrieve(()=>{
this.goToNode(nodeID)
})
})
this.node.value.create(name, nodeID=>this.goToNode(nodeID))
}//}}}
saveNode() {//{{{
let nodeContent = this.nodeContent.current

View File

@ -1,19 +1,20 @@
export class Session {
constructor(app) {//{{{
constructor(app) {
this.app = app
this.UUID = ''
this.initialized = false
this.UserID = 0
}//}}}
}
initialize() {//{{{
// Retrieving the stored session UUID, if any.
// If one found, validate with server.
// If the browser doesn't know anything about a session,
// a call to /session/create is necessary to retrieve a session UUID.
let uuid = window.localStorage.getItem("session.UUID")
if (uuid === null) {
let uuid= window.localStorage.getItem("session.UUID")
if(uuid === null) {
this.create()
return
}
@ -24,47 +25,47 @@ export class Session {
// A call to /session/retrieve with a session UUID validates that the
// session is still valid and returns all session information.
this.UUID = uuid
this.app.request('/_session/retrieve', {})
.then(res => {
if (res.Error === undefined) {
// Session exists on server.
// Not necessarily authenticated.
this.UserID = res.Session.UserID // could be 0
this.initialized = true
this.app.forceUpdate()
} else {
// Session has probably expired. A new is required.
this.create()
}
})
.catch(this.app.responseError)
}//}}}
create() {//{{{
this.app.request('/_session/new', {})
.then(res => {
this.UUID = res.Session.UUID
window.localStorage.setItem('session.UUID', this.Session.UUID)
this.app.request('/session/retrieve', {})
.then(res=>{
if(res.Valid) {
// Session exists on server.
// Not necessarily authenticated.
this.UserID = res.Session.UserID // could be 0
this.initialized = true
this.app.forceUpdate()
})
.catch(this.responseError)
} else {
// Session has probably expired. A new is required.
this.create()
}
})
.catch(this.app.responseError)
}//}}}
create() {//{{{
this.app.request('/session/create', {})
.then(res=>{
this.UUID = res.Session.UUID
window.localStorage.setItem('session.UUID', this.UUID)
this.initialized = true
this.app.forceUpdate()
})
.catch(this.responseError)
}//}}}
authenticate(username, password) {//{{{
this.app.login.current.authentication_failed.value = false
this.app.request('/_session/authenticate', {
this.app.request('/session/authenticate', {
username,
password,
})
.then(res => {
if (res.Authenticated) {
this.UserID = res.UserID
this.app.forceUpdate()
} else {
this.app.login.current.authentication_failed.value = true
}
})
.catch(this.app.responseError)
.then(res=>{
if(res.Authenticated) {
this.UserID = res.Session.UserID
this.app.forceUpdate()
} else {
this.app.login.current.authentication_failed.value = true
}
})
.catch(this.app.responseError)
}//}}}
authenticated() {//{{{
return this.UserID != 0

47
user.go
View File

@ -1,6 +1,11 @@
package main
/*
import (
// Standard
"database/sql"
"fmt"
)
func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err error) {
var result sql.Result
var rowsAffected int64
@ -9,10 +14,10 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
UPDATE public.user
SET
password = password_hash(
/ salt in hex /
/* salt in hex */
ENCODE(gen_random_bytes(16), 'hex'),
/ password /
/* password */
$1::bytea
)
WHERE
@ -31,4 +36,38 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
return rowsAffected > 0, nil
}
*/
func createUser() (err error) {
var username, password string
fmt.Printf("Username: ")
fmt.Scanln(&username)
fmt.Printf("Password: ")
fmt.Scanln(&password)
err = CreateDbUser(username, password)
return
}
func CreateDbUser(username string, password string) (err error) {
var result sql.Result
result, err = db.Exec(`
INSERT INTO public.user(username, password)
VALUES(
$1,
password_hash(
/* salt in hex */
ENCODE(gen_random_bytes(16), 'hex'),
/* password */
$2::bytea
)
)
`,
username,
password,
)
_, err = result.RowsAffected()
return
}