Compare commits
No commits in common. "afa113b8b7066c8db589434fe5fe53f6d34cac3c" and "abbd320b93b849b4bbaa212a17e4c3c8b27f18ef" have entirely different histories.
afa113b8b7
...
abbd320b93
150
db.go
Normal file
150
db.go
Normal file
@ -0,0 +1,150 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
// External
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
// Standard
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
dbConn string
|
||||
db *sqlx.DB
|
||||
)
|
||||
|
||||
func dbInit() (err error) { // {{{
|
||||
dbConn = fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
config.Database.Host,
|
||||
config.Database.Port,
|
||||
config.Database.Username,
|
||||
config.Database.Password,
|
||||
config.Database.Name,
|
||||
)
|
||||
|
||||
logger.Info("db", "op", "connect", "host", config.Database.Host, "port", config.Database.Port)
|
||||
|
||||
if db, err = sqlx.Connect("postgres", dbConn); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = dbVerifyInternals(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = dbUpdate()
|
||||
return
|
||||
} // }}}
|
||||
func dbVerifyInternals() (err error) { // {{{
|
||||
var rows *sqlx.Rows
|
||||
if rows, err = db.Queryx(
|
||||
`SELECT EXISTS (
|
||||
SELECT FROM pg_catalog.pg_class c
|
||||
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = '_internal'
|
||||
AND c.relname = 'db'
|
||||
)`,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
var exists bool
|
||||
rows.Next()
|
||||
if err = rows.Scan(&exists); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !exists {
|
||||
logger.Info("db", "op", "create_db", "db", "_internal.db")
|
||||
if _, err = db.Exec(`
|
||||
CREATE SCHEMA "_internal";
|
||||
|
||||
CREATE TABLE "_internal".db (
|
||||
"key" varchar NOT NULL,
|
||||
value varchar NULL,
|
||||
CONSTRAINT db_pk PRIMARY KEY (key)
|
||||
);
|
||||
|
||||
INSERT INTO _internal.db("key", "value")
|
||||
VALUES('schema', '0');
|
||||
`,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
} // }}}
|
||||
func dbUpdate() (err error) { // {{{
|
||||
/* Current schema revision is read from database.
|
||||
* Used to iterate through the embedded SQL updates
|
||||
* up to the db schema version currently compiled
|
||||
* program is made for. */
|
||||
var rows *sqlx.Rows
|
||||
var schemaStr string
|
||||
var schema int
|
||||
rows, err = db.Queryx(`SELECT value FROM _internal.db WHERE "key"='schema'`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return errors.New("Table _interval.db missing schema row")
|
||||
}
|
||||
|
||||
if err = rows.Scan(&schemaStr); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Run updates
|
||||
schema, err = strconv.Atoi(schemaStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sqlSchemaVersion := sqlSchema()
|
||||
for i := (schema + 1); i <= sqlSchemaVersion; i++ {
|
||||
logger.Info("db", "op", "upgrade_schema", "schema", i)
|
||||
sql, _ := embedded.ReadFile(
|
||||
fmt.Sprintf("sql/%04d.sql", i),
|
||||
)
|
||||
_, err = db.Exec(string(sql))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = db.Exec(`UPDATE _internal.db SET "value"=$1 WHERE "key"='schema'`, i)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
logger.Info("db", "op", "upgrade_schema", "schema", i, "result", "ok")
|
||||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
func sqlSchema() (max int) { // {{{
|
||||
var num int
|
||||
|
||||
files, _ := fs.ReadDir(embedded, "sql")
|
||||
sqlFilename := regexp.MustCompile(`^([0-9]+)\.sql$`)
|
||||
|
||||
for _, file := range files {
|
||||
fname := sqlFilename.FindStringSubmatch(file.Name())
|
||||
if len(fname) != 2 {
|
||||
continue
|
||||
}
|
||||
num, _ = strconv.Atoi(fname[1])
|
||||
}
|
||||
|
||||
if num > max {
|
||||
max = num
|
||||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
|
||||
// vim: foldmethod=marker
|
14
file.go
14
file.go
@ -5,6 +5,7 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
// Standard
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -19,11 +20,11 @@ type File struct {
|
||||
Uploaded time.Time
|
||||
}
|
||||
|
||||
func AddFile(userID int, file *File) (err error) { // {{{
|
||||
file.UserID = userID
|
||||
func (session Session) AddFile(file *File) (err error) { // {{{
|
||||
file.UserID = session.UserID
|
||||
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
INSERT INTO file(user_id, node_id, filename, size, mime, md5)
|
||||
VALUES($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id
|
||||
@ -42,11 +43,12 @@ func AddFile(userID int, file *File) (err error) { // {{{
|
||||
|
||||
rows.Next()
|
||||
err = rows.Scan(&file.ID)
|
||||
fmt.Printf("%#v\n", file)
|
||||
return
|
||||
} // }}}
|
||||
func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
|
||||
func (session Session) Files(nodeID, fileID int) (files []File, err error) { // {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(
|
||||
rows, err = db.Queryx(
|
||||
`SELECT *
|
||||
FROM file
|
||||
WHERE
|
||||
@ -56,7 +58,7 @@ func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
|
||||
WHEN 0 THEN true
|
||||
ELSE id = $3
|
||||
END`,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
fileID,
|
||||
)
|
||||
|
16
key.go
16
key.go
@ -3,6 +3,8 @@ package main
|
||||
import (
|
||||
// External
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
// Standard
|
||||
)
|
||||
|
||||
type Key struct {
|
||||
@ -12,9 +14,9 @@ type Key struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func Keys(userID int) (keys []Key, err error) { // {{{
|
||||
func (session Session) Keys() (keys []Key, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
if rows, err = service.Db.Conn.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, userID); err != nil {
|
||||
if rows, err = db.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, session.UserID); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
@ -30,13 +32,13 @@ func Keys(userID int) (keys []Key, err error) { // {{{
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) { // {{{
|
||||
func (session Session) KeyCreate(description, keyEncoded string) (key Key, err error) {// {{{
|
||||
var row *sqlx.Rows
|
||||
if row, err = service.Db.Conn.Queryx(
|
||||
if row, err = db.Queryx(
|
||||
`INSERT INTO crypto_key(user_id, description, key)
|
||||
VALUES($1, $2, $3)
|
||||
RETURNING *`,
|
||||
userID,
|
||||
session.UserID,
|
||||
description,
|
||||
keyEncoded,
|
||||
); err != nil {
|
||||
@ -53,9 +55,9 @@ func KeyCreate(userID int, description, keyEncoded string) (key Key, err error)
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func KeyCounter() (counter int64, err error) { // {{{
|
||||
func (session Session) KeyCounter() (counter int64, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
|
||||
rows, err = db.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
344
main.go
344
main.go
@ -1,23 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
// External
|
||||
"git.gibonuddevalla.se/go/webservice"
|
||||
|
||||
// Internal
|
||||
"git.gibonuddevalla.se/go/webservice/session"
|
||||
|
||||
// Standard
|
||||
"crypto/md5"
|
||||
"embed"
|
||||
"encoding/hex"
|
||||
"flag"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@ -28,10 +25,8 @@ var (
|
||||
flagPort int
|
||||
flagVersion bool
|
||||
flagCreateUser bool
|
||||
flagCheckLocal bool
|
||||
flagConfig string
|
||||
|
||||
service *webservice.Service
|
||||
connectionManager ConnectionManager
|
||||
static http.Handler
|
||||
config Config
|
||||
@ -39,24 +34,11 @@ var (
|
||||
VERSION string
|
||||
|
||||
//go:embed version sql/*
|
||||
embeddedSQL embed.FS
|
||||
|
||||
//go:embed static
|
||||
staticFS embed.FS
|
||||
embedded embed.FS
|
||||
)
|
||||
|
||||
func sqlProvider(dbname string, version int) (sql []byte, found bool) {
|
||||
var err error
|
||||
sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
found = true
|
||||
return
|
||||
}
|
||||
|
||||
func init() { // {{{
|
||||
version, _ := embeddedSQL.ReadFile("version")
|
||||
version, _ := embedded.ReadFile("version")
|
||||
VERSION = strings.TrimSpace(string(version))
|
||||
|
||||
opt := slog.HandlerOptions{}
|
||||
@ -66,7 +48,6 @@ func init() { // {{{
|
||||
flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on")
|
||||
flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists")
|
||||
flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit")
|
||||
flag.BoolVar(&flagCheckLocal, "checklocal", false, "Check for local static file before embedded")
|
||||
flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file")
|
||||
flag.Parse()
|
||||
} // }}}
|
||||
@ -77,47 +58,56 @@ func main() { // {{{
|
||||
fmt.Printf("%s\n", VERSION)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
logger.Info("application", "version", VERSION)
|
||||
|
||||
service, err = webservice.New(flagConfig, VERSION)
|
||||
config, err = ConfigRead(flagConfig)
|
||||
if err != nil {
|
||||
logger.Error("application", "error", err)
|
||||
logger.Error("config", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
service.SetDatabase(sqlProvider)
|
||||
service.SetStaticDirectory("static", true)
|
||||
service.SetStaticFS(staticFS, "static")
|
||||
|
||||
service.Register("/node/upload", true, true, nodeUpload)
|
||||
service.Register("/node/tree", true, true, nodeTree)
|
||||
service.Register("/node/retrieve", true, true, nodeRetrieve)
|
||||
service.Register("/node/create", true, true, nodeCreate)
|
||||
service.Register("/node/update", true, true, nodeUpdate)
|
||||
service.Register("/node/rename", true, true, nodeRename)
|
||||
service.Register("/node/delete", true, true, nodeDelete)
|
||||
service.Register("/node/download", true, true, nodeDownload)
|
||||
service.Register("/node/search", true, true, nodeSearch)
|
||||
service.Register("/key/retrieve", true, true, keyRetrieve)
|
||||
service.Register("/key/create", true, true, keyCreate)
|
||||
service.Register("/key/counter", true, true, keyCounter)
|
||||
service.Register("/ws", false, false, service.WebsocketHandler)
|
||||
service.Register("/", false, false, service.StaticHandler)
|
||||
if err = dbInit(); err != nil {
|
||||
logger.Error("db", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if flagCreateUser {
|
||||
service.CreateUserPrompt()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
err = service.Start()
|
||||
err = createUser()
|
||||
if err != nil {
|
||||
logger.Error("webserver", "error", err)
|
||||
logger.Error("db", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
connectionManager = NewConnectionManager()
|
||||
go connectionManager.BroadcastLoop()
|
||||
|
||||
static = http.FileServer(http.Dir(config.Application.Directories.Static))
|
||||
http.HandleFunc("/css_updated", cssUpdateHandler)
|
||||
http.HandleFunc("/session/create", sessionCreate)
|
||||
http.HandleFunc("/session/retrieve", sessionRetrieve)
|
||||
http.HandleFunc("/session/authenticate", sessionAuthenticate)
|
||||
http.HandleFunc("/user/password", userPassword)
|
||||
http.HandleFunc("/node/tree", nodeTree)
|
||||
http.HandleFunc("/node/retrieve", nodeRetrieve)
|
||||
http.HandleFunc("/node/create", nodeCreate)
|
||||
http.HandleFunc("/node/update", nodeUpdate)
|
||||
http.HandleFunc("/node/rename", nodeRename)
|
||||
http.HandleFunc("/node/delete", nodeDelete)
|
||||
http.HandleFunc("/node/upload", nodeUpload)
|
||||
http.HandleFunc("/node/download", nodeDownload)
|
||||
http.HandleFunc("/node/search", nodeSearch)
|
||||
http.HandleFunc("/key/retrieve", keyRetrieve)
|
||||
http.HandleFunc("/key/create", keyCreate)
|
||||
http.HandleFunc("/key/counter", keyCounter)
|
||||
http.HandleFunc("/ws", websocketHandler)
|
||||
http.HandleFunc("/", staticHandler)
|
||||
|
||||
listen := fmt.Sprintf("%s:%d", LISTEN_HOST, flagPort)
|
||||
logger.Info("webserver", "listen", listen, "domains", config.Websocket.Domains)
|
||||
http.ListenAndServe(listen, nil)
|
||||
} // }}}
|
||||
|
||||
func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
@ -137,8 +127,108 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
return
|
||||
}
|
||||
} // }}}
|
||||
func staticHandler(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
data := struct {
|
||||
VERSION string
|
||||
}{
|
||||
VERSION: VERSION,
|
||||
}
|
||||
|
||||
// URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version.
|
||||
// To get rid of problems with cached content in browser on a new version release,
|
||||
// while also not disabling cache altogether.
|
||||
logger.Debug("webserver", "request", r.URL.Path)
|
||||
if r.URL.Path == "/favicon.ico" {
|
||||
static.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
|
||||
if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
|
||||
r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
|
||||
static.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Everything else is run through the template system.
|
||||
// For now to get VERSION into files to fix caching.
|
||||
logger.Debug("webserver", "template", r.URL.Path)
|
||||
tmpl, err := newTemplate(r.URL.Path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err = tmpl.Execute(w, data); err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
}
|
||||
} // }}}
|
||||
|
||||
func sessionCreate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/session/create")
|
||||
session, err := CreateSession()
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
responseData(w, map[string]interface{}{
|
||||
"OK": true,
|
||||
"Session": session,
|
||||
})
|
||||
} // }}}
|
||||
func sessionRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/session/retrieve")
|
||||
var err error
|
||||
var found bool
|
||||
var session Session
|
||||
|
||||
if session, found, err = ValidateSession(r, false); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
responseData(w, map[string]interface{}{
|
||||
"OK": true,
|
||||
"Valid": found,
|
||||
"Session": session,
|
||||
})
|
||||
} // }}}
|
||||
func sessionAuthenticate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/session/authenticate")
|
||||
var err error
|
||||
var session Session
|
||||
var authenticated bool
|
||||
|
||||
// Validate session
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Username string
|
||||
Password string
|
||||
}{}
|
||||
if err = parseRequest(r, &req); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if authenticated, err = session.Authenticate(req.Username, req.Password); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
responseData(w, map[string]interface{}{
|
||||
"OK": true,
|
||||
"Authenticated": authenticated,
|
||||
"Session": session,
|
||||
})
|
||||
} // }}}
|
||||
|
||||
/*
|
||||
func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
var err error
|
||||
var ok bool
|
||||
@ -169,11 +259,15 @@ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
"CurrentPasswordOK": ok,
|
||||
})
|
||||
} // }}}
|
||||
*/
|
||||
|
||||
func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
logger.Info("webserver", "request", "/node/tree")
|
||||
func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct{ StartNodeID int }{}
|
||||
if err = parseRequest(r, &req); err != nil {
|
||||
@ -181,7 +275,7 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
return
|
||||
}
|
||||
|
||||
nodes, err := NodeTree(sess.UserID, req.StartNodeID)
|
||||
nodes, err := session.NodeTree(req.StartNodeID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -192,9 +286,15 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
"Nodes": nodes,
|
||||
})
|
||||
} // }}}
|
||||
func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/retrieve")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct{ ID int }{}
|
||||
if err = parseRequest(r, &req); err != nil {
|
||||
@ -202,7 +302,7 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
||||
return
|
||||
}
|
||||
|
||||
node, err := RetrieveNode(sess.UserID, req.ID)
|
||||
node, err := session.Node(req.ID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -213,9 +313,15 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
||||
"Node": node,
|
||||
})
|
||||
} // }}}
|
||||
func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/create")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Name string
|
||||
@ -226,7 +332,7 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
return
|
||||
}
|
||||
|
||||
node, err := CreateNode(sess.UserID, req.ParentID, req.Name)
|
||||
node, err := session.CreateNode(req.ParentID, req.Name)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -237,9 +343,15 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
"Node": node,
|
||||
})
|
||||
} // }}}
|
||||
func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/update")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
@ -251,7 +363,7 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
return
|
||||
}
|
||||
|
||||
err = UpdateNode(sess.UserID, req.NodeID, req.Content, req.CryptoKeyID)
|
||||
err = session.UpdateNode(req.NodeID, req.Content, req.CryptoKeyID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -261,10 +373,17 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
"OK": true,
|
||||
})
|
||||
} // }}}
|
||||
func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
var err error
|
||||
func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/rename")
|
||||
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
Name string
|
||||
@ -274,7 +393,7 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
return
|
||||
}
|
||||
|
||||
err = RenameNode(sess.UserID, req.NodeID, req.Name)
|
||||
err = session.RenameNode(req.NodeID, req.Name)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -284,10 +403,17 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
"OK": true,
|
||||
})
|
||||
} // }}}
|
||||
func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
var err error
|
||||
func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/delete")
|
||||
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
}{}
|
||||
@ -296,7 +422,7 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
return
|
||||
}
|
||||
|
||||
err = DeleteNode(sess.UserID, req.NodeID)
|
||||
err = session.DeleteNode(req.NodeID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -306,9 +432,15 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
"OK": true,
|
||||
})
|
||||
} // }}}
|
||||
func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/upload")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse our multipart form, 10 << 20 specifies a maximum
|
||||
// upload of 10 MB files.
|
||||
@ -347,7 +479,7 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
MIME: handler.Header.Get("Content-Type"),
|
||||
MD5: md5sum,
|
||||
}
|
||||
if err = AddFile(sess.UserID, &nodeFile); err != nil {
|
||||
if err = session.AddFile(&nodeFile); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
@ -384,11 +516,17 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
"File": nodeFile,
|
||||
})
|
||||
} // }}}
|
||||
func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/download")
|
||||
var err error
|
||||
var session Session
|
||||
var files []File
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
FileID int
|
||||
@ -398,7 +536,7 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
||||
return
|
||||
}
|
||||
|
||||
files, err = Files(sess.UserID, req.NodeID, req.FileID)
|
||||
files, err = session.Files(req.NodeID, req.FileID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -445,11 +583,17 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
||||
}
|
||||
|
||||
} // }}}
|
||||
func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/files")
|
||||
var err error
|
||||
var session Session
|
||||
var files []File
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
}{}
|
||||
@ -458,7 +602,7 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
return
|
||||
}
|
||||
|
||||
files, err = Files(sess.UserID, req.NodeID, 0)
|
||||
files, err = session.Files(req.NodeID, 0)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -469,11 +613,17 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
"Files": files,
|
||||
})
|
||||
} // }}}
|
||||
func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/search")
|
||||
var err error
|
||||
var session Session
|
||||
var nodes []Node
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Search string
|
||||
}{}
|
||||
@ -482,7 +632,7 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
return
|
||||
}
|
||||
|
||||
nodes, err = SearchNodes(sess.UserID, req.Search)
|
||||
nodes, err = session.SearchNodes(req.Search)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -494,11 +644,17 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
})
|
||||
} // }}}
|
||||
|
||||
func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/key/retrieve")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
keys, err := Keys(sess.UserID)
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := session.Keys()
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -509,9 +665,15 @@ func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {
|
||||
"Keys": keys,
|
||||
})
|
||||
} // }}}
|
||||
func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/key/create")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Description string
|
||||
@ -522,7 +684,7 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
return
|
||||
}
|
||||
|
||||
key, err := KeyCreate(sess.UserID, req.Description, req.Key)
|
||||
key, err := session.KeyCreate(req.Description, req.Key)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -533,11 +695,17 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
"Key": key,
|
||||
})
|
||||
} // }}}
|
||||
func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/key/counter")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
counter, err := KeyCounter()
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
counter, err := session.KeyCounter()
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
@ -550,4 +718,20 @@ func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
||||
})
|
||||
} // }}}
|
||||
|
||||
func newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{
|
||||
// Append index.html if needed for further reading of the file
|
||||
p := requestPath
|
||||
if p[len(p)-1] == '/' {
|
||||
p += "index.html"
|
||||
}
|
||||
p = config.Application.Directories.Static + p
|
||||
|
||||
base := path.Base(p)
|
||||
if tmpl, err = template.New(base).ParseFiles(p); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
|
||||
// vim: foldmethod=marker
|
||||
|
74
node.go
74
node.go
@ -25,9 +25,9 @@ type Node struct {
|
||||
ContentEncrypted string `db:"content_encrypted" json:"-"`
|
||||
}
|
||||
|
||||
func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
||||
func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
WITH RECURSIVE nodetree AS (
|
||||
SELECT
|
||||
*,
|
||||
@ -62,7 +62,7 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
||||
ORDER BY
|
||||
path ASC
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
startNodeID,
|
||||
)
|
||||
if err != nil {
|
||||
@ -79,9 +79,6 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
||||
for rows.Next() {
|
||||
node := Node{}
|
||||
node.Complete = false
|
||||
node.Crumbs = []Node{}
|
||||
node.Children = []Node{}
|
||||
node.Files = []File{}
|
||||
if err = rows.StructScan(&node); err != nil {
|
||||
return
|
||||
}
|
||||
@ -90,9 +87,9 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func RootNode(userID int) (node Node, err error) {// {{{
|
||||
func (session Session) RootNode() (node Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
@ -103,7 +100,7 @@ func RootNode(userID int) (node Node, err error) {// {{{
|
||||
user_id = $1 AND
|
||||
parent_id IS NULL
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
@ -111,7 +108,7 @@ func RootNode(userID int) (node Node, err error) {// {{{
|
||||
defer rows.Close()
|
||||
|
||||
node.Name = "Start"
|
||||
node.UserID = userID
|
||||
node.UserID = session.UserID
|
||||
node.Complete = true
|
||||
node.Children = []Node{}
|
||||
node.Crumbs = []Node{}
|
||||
@ -132,13 +129,13 @@ func RootNode(userID int) (node Node, err error) {// {{{
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
|
||||
func (session Session) Node(nodeID int) (node Node, err error) {// {{{
|
||||
if nodeID == 0 {
|
||||
return RootNode(userID)
|
||||
return session.RootNode()
|
||||
}
|
||||
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
WITH RECURSIVE recurse AS (
|
||||
SELECT
|
||||
id,
|
||||
@ -173,7 +170,7 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
|
||||
|
||||
SELECT * FROM recurse ORDER BY level ASC
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
)
|
||||
if err != nil {
|
||||
@ -220,14 +217,14 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
|
||||
}
|
||||
}
|
||||
|
||||
node.Crumbs, err = NodeCrumbs(node.ID)
|
||||
node.Files, err = Files(userID, node.ID, 0)
|
||||
node.Crumbs, err = session.NodeCrumbs(node.ID)
|
||||
node.Files, err = session.Files(node.ID, 0)
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
|
||||
func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
WITH RECURSIVE nodes AS (
|
||||
SELECT
|
||||
id,
|
||||
@ -263,10 +260,10 @@ func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
|
||||
}
|
||||
return
|
||||
}// }}}
|
||||
func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{{
|
||||
func (session Session) CreateNode(parentID int, name string) (node Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
INSERT INTO node(user_id, parent_id, name)
|
||||
VALUES($1, NULLIF($2, 0)::integer, $3)
|
||||
RETURNING
|
||||
@ -276,7 +273,7 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{
|
||||
name,
|
||||
content
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
parentID,
|
||||
name,
|
||||
)
|
||||
@ -295,12 +292,12 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{
|
||||
node.Complete = true
|
||||
}
|
||||
|
||||
node.Crumbs, err = NodeCrumbs(node.ID)
|
||||
node.Crumbs, err = session.NodeCrumbs(node.ID)
|
||||
return
|
||||
}// }}}
|
||||
func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) {// {{{
|
||||
func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (err error) {// {{{
|
||||
if cryptoKeyID > 0 {
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
_, err = db.Exec(`
|
||||
UPDATE node
|
||||
SET
|
||||
content = '',
|
||||
@ -316,10 +313,10 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error)
|
||||
content,
|
||||
cryptoKeyID,
|
||||
nodeID,
|
||||
userID,
|
||||
session.UserID,
|
||||
)
|
||||
} else {
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
_, err = db.Exec(`
|
||||
UPDATE node
|
||||
SET
|
||||
content = $1,
|
||||
@ -335,24 +332,24 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error)
|
||||
content,
|
||||
cryptoKeyID,
|
||||
nodeID,
|
||||
userID,
|
||||
session.UserID,
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func RenameNode(userID, nodeID int, name string) (err error) {// {{{
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
func (session Session) RenameNode(nodeID int, name string) (err error) {// {{{
|
||||
_, err = db.Exec(`
|
||||
UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3
|
||||
`,
|
||||
name,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
)
|
||||
return
|
||||
}// }}}
|
||||
func DeleteNode(userID, nodeID int) (err error) {// {{{
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
func (session Session) DeleteNode(nodeID int) (err error) {// {{{
|
||||
_, err = db.Exec(`
|
||||
WITH RECURSIVE nodetree AS (
|
||||
SELECT
|
||||
id, parent_id
|
||||
@ -371,15 +368,15 @@ func DeleteNode(userID, nodeID int) (err error) {// {{{
|
||||
DELETE FROM node WHERE id IN (
|
||||
SELECT id FROM nodetree
|
||||
)`,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
)
|
||||
return
|
||||
}// }}}
|
||||
func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
|
||||
func (session Session) SearchNodes(search string) (nodes []Node, err error) {// {{{
|
||||
nodes = []Node{}
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
@ -388,15 +385,14 @@ func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
|
||||
updated
|
||||
FROM node
|
||||
WHERE
|
||||
user_id = $1 AND
|
||||
crypto_key_id IS NULL AND
|
||||
(
|
||||
content ~* $2 OR
|
||||
name ~* $2
|
||||
content ~* $1 OR
|
||||
name ~* $1
|
||||
)
|
||||
ORDER BY
|
||||
updated DESC
|
||||
`, userID, search)
|
||||
`, search)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
119
session.go
Normal file
119
session.go
Normal file
@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
// Standard
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
UserID int
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
func CreateSession() (session Session, err error) {// {{{
|
||||
var rows *sql.Rows
|
||||
if rows, err = db.Query(`
|
||||
INSERT INTO public.session(uuid)
|
||||
VALUES(gen_random_uuid())
|
||||
RETURNING uuid, created`,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
rows.Scan(&session.UUID, &session.Created)
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
|
||||
func sessionUUID(r *http.Request) (string, error) {// {{{
|
||||
headers := r.Header["X-Session-Id"]
|
||||
if len(headers) > 0 {
|
||||
return headers[0], nil
|
||||
}
|
||||
return "", errors.New("Invalid session")
|
||||
}// }}}
|
||||
func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{
|
||||
var uuid string
|
||||
if uuid, err = sessionUUID(r); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
session.UUID = uuid
|
||||
if found, err = session.Retrieve(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if notFoundIsError && !found {
|
||||
err = errors.New("Invalid session")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
|
||||
func (session *Session) Retrieve() (found bool, err error) {// {{{
|
||||
var rows *sql.Rows
|
||||
if rows, err = db.Query(`
|
||||
SELECT
|
||||
uuid, user_id, created
|
||||
FROM public.session
|
||||
WHERE
|
||||
uuid = $1 AND
|
||||
created + $2::interval >= NOW()
|
||||
`,
|
||||
session.UUID,
|
||||
fmt.Sprintf("%d days", config.Session.DaysValid),
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
found = false
|
||||
if rows.Next() {
|
||||
found = true
|
||||
rows.Scan(&session.UUID, &session.UserID, &session.Created)
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{
|
||||
var rows *sql.Rows
|
||||
if rows, err = db.Query(`
|
||||
SELECT id
|
||||
FROM public.user
|
||||
WHERE
|
||||
username=$1 AND
|
||||
password=password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
|
||||
`,
|
||||
username,
|
||||
password,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
rows.Scan(&session.UserID)
|
||||
authenticated = session.UserID > 0
|
||||
}
|
||||
|
||||
if authenticated {
|
||||
_, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
|
||||
|
||||
// vim: foldmethod=marker
|
34
sql/0013.sql
Normal file
34
sql/0013.sql
Normal file
@ -0,0 +1,34 @@
|
||||
/* Required for the gen_random_bytes function */
|
||||
CREATE EXTENSION pgcrypto;
|
||||
|
||||
CREATE FUNCTION password_hash(salt_hex char(32), pass bytea)
|
||||
RETURNS char(96)
|
||||
LANGUAGE plpgsql
|
||||
AS
|
||||
$$
|
||||
BEGIN
|
||||
RETURN (
|
||||
SELECT
|
||||
salt_hex ||
|
||||
encode(
|
||||
sha256(
|
||||
decode(salt_hex, 'hex') || /* salt in binary */
|
||||
pass /* password */
|
||||
),
|
||||
'hex'
|
||||
)
|
||||
);
|
||||
END;
|
||||
$$;
|
||||
|
||||
/* Password has to be able to accommodate 96 characters instead of previous 64.
|
||||
* It can't be char(96), because then the password would be padded to 96 characters. */
|
||||
ALTER TABLE public."user" ALTER COLUMN "password" TYPE varchar(96) USING "password"::varchar;
|
||||
|
||||
/* Update all users with salted and hashed passwords */
|
||||
UPDATE public.user
|
||||
SET password = password_hash( encode(gen_random_bytes(16),'hex'), password::bytea);
|
||||
|
||||
/* After the password hashing, all passwords are now hex encoded 32 characters salt and 64 characters hash,
|
||||
* and the varchar type is not longer necessary. */
|
||||
ALTER TABLE public."user" ALTER COLUMN "password" TYPE char(96) USING "password"::varchar;
|
@ -13,8 +13,8 @@ class App extends Component {
|
||||
this.websocket = null
|
||||
this.websocket_int_ping = null
|
||||
this.websocket_int_reconnect = null
|
||||
//this.wsConnect() // XXX
|
||||
//this.wsLoop() // XXX
|
||||
this.wsConnect()
|
||||
this.wsLoop()
|
||||
|
||||
this.session = new Session(this)
|
||||
this.session.initialize()
|
||||
@ -114,7 +114,7 @@ class App extends Component {
|
||||
}
|
||||
|
||||
if(this.session.UUID !== '')
|
||||
headers['X-Session-ID'] = this.session.UUID
|
||||
headers['X-Session-Id'] = this.session.UUID
|
||||
|
||||
fetch(url, {
|
||||
method: 'POST',
|
||||
@ -201,24 +201,9 @@ class Tree extends Component {
|
||||
this.selectedTreeNode = null
|
||||
|
||||
this.props.app.tree = this
|
||||
this.retrieve()
|
||||
}//}}}
|
||||
render({ app }) {//{{{
|
||||
let renderedTreeTrunk = this.treeTrunk.map(node=>{
|
||||
this.treeNodeComponents[node.ID] = createRef()
|
||||
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
|
||||
})
|
||||
return html`<div id="tree">${renderedTreeTrunk}</div>`
|
||||
}//}}}
|
||||
|
||||
retrieve(callback = null) {//{{{
|
||||
this.props.app.request('/node/tree', { StartNodeID: 0 })
|
||||
.then(res=>{
|
||||
this.treeNodes = {}
|
||||
this.treeNodeComponents = {}
|
||||
this.treeTrunk = []
|
||||
this.selectedTreeNode = null
|
||||
|
||||
// A tree of nodes is built. This requires the list of nodes
|
||||
// returned from the server to be sorted in such a way that
|
||||
// a parent node always appears before a child node.
|
||||
@ -250,12 +235,17 @@ class Tree extends Component {
|
||||
this.crumbsUpdateNodes()
|
||||
this.forceUpdate()
|
||||
|
||||
if(callback)
|
||||
callback()
|
||||
|
||||
})
|
||||
.catch(this.responseError)
|
||||
}//}}}
|
||||
render({ app }) {//{{{
|
||||
let renderedTreeTrunk = this.treeTrunk.map(node=>{
|
||||
this.treeNodeComponents[node.ID] = createRef()
|
||||
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
|
||||
})
|
||||
return html`<div id="tree">${renderedTreeTrunk}</div>`
|
||||
}//}}}
|
||||
|
||||
setSelected(node) {//{{{
|
||||
if(this.selectedTreeNode)
|
||||
this.selectedTreeNode.selected.value = false
|
||||
|
@ -215,14 +215,7 @@ export class NodeUI extends Component {
|
||||
let name = prompt("Name")
|
||||
if(!name)
|
||||
return
|
||||
this.node.value.create(name, nodeID=>{
|
||||
console.log('before', this.props.app.startNode)
|
||||
this.props.app.startNode = new Node(this.props.app, nodeID)
|
||||
console.log('after', this.props.app.startNode)
|
||||
this.props.app.tree.retrieve(()=>{
|
||||
this.goToNode(nodeID)
|
||||
})
|
||||
})
|
||||
this.node.value.create(name, nodeID=>this.goToNode(nodeID))
|
||||
}//}}}
|
||||
saveNode() {//{{{
|
||||
let nodeContent = this.nodeContent.current
|
||||
|
@ -1,15 +1,16 @@
|
||||
export class Session {
|
||||
constructor(app) {//{{{
|
||||
constructor(app) {
|
||||
this.app = app
|
||||
this.UUID = ''
|
||||
this.initialized = false
|
||||
this.UserID = 0
|
||||
}//}}}
|
||||
}
|
||||
|
||||
initialize() {//{{{
|
||||
// Retrieving the stored session UUID, if any.
|
||||
// If one found, validate with server.
|
||||
|
||||
|
||||
// If the browser doesn't know anything about a session,
|
||||
// a call to /session/create is necessary to retrieve a session UUID.
|
||||
let uuid= window.localStorage.getItem("session.UUID")
|
||||
@ -24,9 +25,9 @@ export class Session {
|
||||
// A call to /session/retrieve with a session UUID validates that the
|
||||
// session is still valid and returns all session information.
|
||||
this.UUID = uuid
|
||||
this.app.request('/_session/retrieve', {})
|
||||
this.app.request('/session/retrieve', {})
|
||||
.then(res=>{
|
||||
if (res.Error === undefined) {
|
||||
if(res.Valid) {
|
||||
// Session exists on server.
|
||||
// Not necessarily authenticated.
|
||||
this.UserID = res.Session.UserID // could be 0
|
||||
@ -40,10 +41,10 @@ export class Session {
|
||||
.catch(this.app.responseError)
|
||||
}//}}}
|
||||
create() {//{{{
|
||||
this.app.request('/_session/new', {})
|
||||
this.app.request('/session/create', {})
|
||||
.then(res=>{
|
||||
this.UUID = res.Session.UUID
|
||||
window.localStorage.setItem('session.UUID', this.Session.UUID)
|
||||
window.localStorage.setItem('session.UUID', this.UUID)
|
||||
this.initialized = true
|
||||
this.app.forceUpdate()
|
||||
})
|
||||
@ -52,13 +53,13 @@ export class Session {
|
||||
authenticate(username, password) {//{{{
|
||||
this.app.login.current.authentication_failed.value = false
|
||||
|
||||
this.app.request('/_session/authenticate', {
|
||||
this.app.request('/session/authenticate', {
|
||||
username,
|
||||
password,
|
||||
})
|
||||
.then(res=>{
|
||||
if(res.Authenticated) {
|
||||
this.UserID = res.UserID
|
||||
this.UserID = res.Session.UserID
|
||||
this.app.forceUpdate()
|
||||
} else {
|
||||
this.app.login.current.authentication_failed.value = true
|
||||
|
47
user.go
47
user.go
@ -1,6 +1,11 @@
|
||||
package main
|
||||
|
||||
/*
|
||||
import (
|
||||
// Standard
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err error) {
|
||||
var result sql.Result
|
||||
var rowsAffected int64
|
||||
@ -9,10 +14,10 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
|
||||
UPDATE public.user
|
||||
SET
|
||||
password = password_hash(
|
||||
/ salt in hex /
|
||||
/* salt in hex */
|
||||
ENCODE(gen_random_bytes(16), 'hex'),
|
||||
|
||||
/ password /
|
||||
/* password */
|
||||
$1::bytea
|
||||
)
|
||||
WHERE
|
||||
@ -31,4 +36,38 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
|
||||
|
||||
return rowsAffected > 0, nil
|
||||
}
|
||||
*/
|
||||
|
||||
func createUser() (err error) {
|
||||
var username, password string
|
||||
fmt.Printf("Username: ")
|
||||
fmt.Scanln(&username)
|
||||
fmt.Printf("Password: ")
|
||||
fmt.Scanln(&password)
|
||||
|
||||
err = CreateDbUser(username, password)
|
||||
return
|
||||
}
|
||||
|
||||
func CreateDbUser(username string, password string) (err error) {
|
||||
var result sql.Result
|
||||
|
||||
result, err = db.Exec(`
|
||||
INSERT INTO public.user(username, password)
|
||||
VALUES(
|
||||
$1,
|
||||
password_hash(
|
||||
/* salt in hex */
|
||||
ENCODE(gen_random_bytes(16), 'hex'),
|
||||
|
||||
/* password */
|
||||
$2::bytea
|
||||
)
|
||||
)
|
||||
`,
|
||||
username,
|
||||
password,
|
||||
)
|
||||
|
||||
_, err = result.RowsAffected()
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user