Compare commits
No commits in common. "afa113b8b7066c8db589434fe5fe53f6d34cac3c" and "abbd320b93b849b4bbaa212a17e4c3c8b27f18ef" have entirely different histories.
afa113b8b7
...
abbd320b93
23 changed files with 715 additions and 205 deletions
150
db.go
Normal file
150
db.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// External
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
// Standard
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
dbConn string
|
||||
db *sqlx.DB
|
||||
)
|
||||
|
||||
func dbInit() (err error) { // {{{
|
||||
dbConn = fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
config.Database.Host,
|
||||
config.Database.Port,
|
||||
config.Database.Username,
|
||||
config.Database.Password,
|
||||
config.Database.Name,
|
||||
)
|
||||
|
||||
logger.Info("db", "op", "connect", "host", config.Database.Host, "port", config.Database.Port)
|
||||
|
||||
if db, err = sqlx.Connect("postgres", dbConn); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = dbVerifyInternals(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = dbUpdate()
|
||||
return
|
||||
} // }}}
|
||||
func dbVerifyInternals() (err error) { // {{{
|
||||
var rows *sqlx.Rows
|
||||
if rows, err = db.Queryx(
|
||||
`SELECT EXISTS (
|
||||
SELECT FROM pg_catalog.pg_class c
|
||||
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = '_internal'
|
||||
AND c.relname = 'db'
|
||||
)`,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
var exists bool
|
||||
rows.Next()
|
||||
if err = rows.Scan(&exists); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !exists {
|
||||
logger.Info("db", "op", "create_db", "db", "_internal.db")
|
||||
if _, err = db.Exec(`
|
||||
CREATE SCHEMA "_internal";
|
||||
|
||||
CREATE TABLE "_internal".db (
|
||||
"key" varchar NOT NULL,
|
||||
value varchar NULL,
|
||||
CONSTRAINT db_pk PRIMARY KEY (key)
|
||||
);
|
||||
|
||||
INSERT INTO _internal.db("key", "value")
|
||||
VALUES('schema', '0');
|
||||
`,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
} // }}}
|
||||
func dbUpdate() (err error) { // {{{
|
||||
/* Current schema revision is read from database.
|
||||
* Used to iterate through the embedded SQL updates
|
||||
* up to the db schema version currently compiled
|
||||
* program is made for. */
|
||||
var rows *sqlx.Rows
|
||||
var schemaStr string
|
||||
var schema int
|
||||
rows, err = db.Queryx(`SELECT value FROM _internal.db WHERE "key"='schema'`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return errors.New("Table _interval.db missing schema row")
|
||||
}
|
||||
|
||||
if err = rows.Scan(&schemaStr); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Run updates
|
||||
schema, err = strconv.Atoi(schemaStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sqlSchemaVersion := sqlSchema()
|
||||
for i := (schema + 1); i <= sqlSchemaVersion; i++ {
|
||||
logger.Info("db", "op", "upgrade_schema", "schema", i)
|
||||
sql, _ := embedded.ReadFile(
|
||||
fmt.Sprintf("sql/%04d.sql", i),
|
||||
)
|
||||
_, err = db.Exec(string(sql))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = db.Exec(`UPDATE _internal.db SET "value"=$1 WHERE "key"='schema'`, i)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
logger.Info("db", "op", "upgrade_schema", "schema", i, "result", "ok")
|
||||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
func sqlSchema() (max int) { // {{{
|
||||
var num int
|
||||
|
||||
files, _ := fs.ReadDir(embedded, "sql")
|
||||
sqlFilename := regexp.MustCompile(`^([0-9]+)\.sql$`)
|
||||
|
||||
for _, file := range files {
|
||||
fname := sqlFilename.FindStringSubmatch(file.Name())
|
||||
if len(fname) != 2 {
|
||||
continue
|
||||
}
|
||||
num, _ = strconv.Atoi(fname[1])
|
||||
}
|
||||
|
||||
if num > max {
|
||||
max = num
|
||||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
|
||||
// vim: foldmethod=marker
|
||||
14
file.go
14
file.go
|
|
@ -5,6 +5,7 @@ import (
|
|||
"github.com/jmoiron/sqlx"
|
||||
|
||||
// Standard
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -19,11 +20,11 @@ type File struct {
|
|||
Uploaded time.Time
|
||||
}
|
||||
|
||||
func AddFile(userID int, file *File) (err error) { // {{{
|
||||
file.UserID = userID
|
||||
func (session Session) AddFile(file *File) (err error) { // {{{
|
||||
file.UserID = session.UserID
|
||||
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
INSERT INTO file(user_id, node_id, filename, size, mime, md5)
|
||||
VALUES($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id
|
||||
|
|
@ -42,11 +43,12 @@ func AddFile(userID int, file *File) (err error) { // {{{
|
|||
|
||||
rows.Next()
|
||||
err = rows.Scan(&file.ID)
|
||||
fmt.Printf("%#v\n", file)
|
||||
return
|
||||
} // }}}
|
||||
func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
|
||||
func (session Session) Files(nodeID, fileID int) (files []File, err error) { // {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(
|
||||
rows, err = db.Queryx(
|
||||
`SELECT *
|
||||
FROM file
|
||||
WHERE
|
||||
|
|
@ -56,7 +58,7 @@ func Files(userID, nodeID, fileID int) (files []File, err error) { // {{{
|
|||
WHEN 0 THEN true
|
||||
ELSE id = $3
|
||||
END`,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
fileID,
|
||||
)
|
||||
|
|
|
|||
22
key.go
22
key.go
|
|
@ -3,6 +3,8 @@ package main
|
|||
import (
|
||||
// External
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
// Standard
|
||||
)
|
||||
|
||||
type Key struct {
|
||||
|
|
@ -12,9 +14,9 @@ type Key struct {
|
|||
Key string
|
||||
}
|
||||
|
||||
func Keys(userID int) (keys []Key, err error) { // {{{
|
||||
func (session Session) Keys() (keys []Key, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
if rows, err = service.Db.Conn.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, userID); err != nil {
|
||||
if rows, err = db.Queryx(`SELECT * FROM crypto_key WHERE user_id=$1`, session.UserID); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
|
@ -29,14 +31,14 @@ func Keys(userID int) (keys []Key, err error) { // {{{
|
|||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
func KeyCreate(userID int, description, keyEncoded string) (key Key, err error) { // {{{
|
||||
}// }}}
|
||||
func (session Session) KeyCreate(description, keyEncoded string) (key Key, err error) {// {{{
|
||||
var row *sqlx.Rows
|
||||
if row, err = service.Db.Conn.Queryx(
|
||||
if row, err = db.Queryx(
|
||||
`INSERT INTO crypto_key(user_id, description, key)
|
||||
VALUES($1, $2, $3)
|
||||
RETURNING *`,
|
||||
userID,
|
||||
session.UserID,
|
||||
description,
|
||||
keyEncoded,
|
||||
); err != nil {
|
||||
|
|
@ -52,10 +54,10 @@ func KeyCreate(userID int, description, keyEncoded string) (key Key, err error)
|
|||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
func KeyCounter() (counter int64, err error) { // {{{
|
||||
}// }}}
|
||||
func (session Session) KeyCounter() (counter int64, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
|
||||
rows, err = db.Queryx(`SELECT nextval('aes_ccm_counter') AS counter`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -64,6 +66,6 @@ func KeyCounter() (counter int64, err error) { // {{{
|
|||
rows.Next()
|
||||
err = rows.Scan(&counter)
|
||||
return
|
||||
} // }}}
|
||||
}// }}}
|
||||
|
||||
// vim: foldmethod=marker
|
||||
|
|
|
|||
346
main.go
346
main.go
|
|
@ -1,23 +1,20 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// External
|
||||
"git.gibonuddevalla.se/go/webservice"
|
||||
|
||||
// Internal
|
||||
"git.gibonuddevalla.se/go/webservice/session"
|
||||
|
||||
// Standard
|
||||
"crypto/md5"
|
||||
"embed"
|
||||
"encoding/hex"
|
||||
"flag"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
|
@ -28,10 +25,8 @@ var (
|
|||
flagPort int
|
||||
flagVersion bool
|
||||
flagCreateUser bool
|
||||
flagCheckLocal bool
|
||||
flagConfig string
|
||||
|
||||
service *webservice.Service
|
||||
connectionManager ConnectionManager
|
||||
static http.Handler
|
||||
config Config
|
||||
|
|
@ -39,24 +34,11 @@ var (
|
|||
VERSION string
|
||||
|
||||
//go:embed version sql/*
|
||||
embeddedSQL embed.FS
|
||||
|
||||
//go:embed static
|
||||
staticFS embed.FS
|
||||
embedded embed.FS
|
||||
)
|
||||
|
||||
func sqlProvider(dbname string, version int) (sql []byte, found bool) {
|
||||
var err error
|
||||
sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
found = true
|
||||
return
|
||||
}
|
||||
|
||||
func init() { // {{{
|
||||
version, _ := embeddedSQL.ReadFile("version")
|
||||
version, _ := embedded.ReadFile("version")
|
||||
VERSION = strings.TrimSpace(string(version))
|
||||
|
||||
opt := slog.HandlerOptions{}
|
||||
|
|
@ -66,7 +48,6 @@ func init() { // {{{
|
|||
flag.IntVar(&flagPort, "port", 1371, "TCP port to listen on")
|
||||
flag.BoolVar(&flagVersion, "version", false, "Shows Notes version and exists")
|
||||
flag.BoolVar(&flagCreateUser, "createuser", false, "Create a user and exit")
|
||||
flag.BoolVar(&flagCheckLocal, "checklocal", false, "Check for local static file before embedded")
|
||||
flag.StringVar(&flagConfig, "config", configFilename, "Filename of configuration file")
|
||||
flag.Parse()
|
||||
} // }}}
|
||||
|
|
@ -77,47 +58,56 @@ func main() { // {{{
|
|||
fmt.Printf("%s\n", VERSION)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
logger.Info("application", "version", VERSION)
|
||||
|
||||
service, err = webservice.New(flagConfig, VERSION)
|
||||
config, err = ConfigRead(flagConfig)
|
||||
if err != nil {
|
||||
logger.Error("application", "error", err)
|
||||
logger.Error("config", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
service.SetDatabase(sqlProvider)
|
||||
service.SetStaticDirectory("static", true)
|
||||
service.SetStaticFS(staticFS, "static")
|
||||
|
||||
service.Register("/node/upload", true, true, nodeUpload)
|
||||
service.Register("/node/tree", true, true, nodeTree)
|
||||
service.Register("/node/retrieve", true, true, nodeRetrieve)
|
||||
service.Register("/node/create", true, true, nodeCreate)
|
||||
service.Register("/node/update", true, true, nodeUpdate)
|
||||
service.Register("/node/rename", true, true, nodeRename)
|
||||
service.Register("/node/delete", true, true, nodeDelete)
|
||||
service.Register("/node/download", true, true, nodeDownload)
|
||||
service.Register("/node/search", true, true, nodeSearch)
|
||||
service.Register("/key/retrieve", true, true, keyRetrieve)
|
||||
service.Register("/key/create", true, true, keyCreate)
|
||||
service.Register("/key/counter", true, true, keyCounter)
|
||||
service.Register("/ws", false, false, service.WebsocketHandler)
|
||||
service.Register("/", false, false, service.StaticHandler)
|
||||
if err = dbInit(); err != nil {
|
||||
logger.Error("db", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if flagCreateUser {
|
||||
service.CreateUserPrompt()
|
||||
err = createUser()
|
||||
if err != nil {
|
||||
logger.Error("db", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
err = service.Start()
|
||||
if err != nil {
|
||||
logger.Error("webserver", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
connectionManager = NewConnectionManager()
|
||||
go connectionManager.BroadcastLoop()
|
||||
|
||||
static = http.FileServer(http.Dir(config.Application.Directories.Static))
|
||||
http.HandleFunc("/css_updated", cssUpdateHandler)
|
||||
http.HandleFunc("/session/create", sessionCreate)
|
||||
http.HandleFunc("/session/retrieve", sessionRetrieve)
|
||||
http.HandleFunc("/session/authenticate", sessionAuthenticate)
|
||||
http.HandleFunc("/user/password", userPassword)
|
||||
http.HandleFunc("/node/tree", nodeTree)
|
||||
http.HandleFunc("/node/retrieve", nodeRetrieve)
|
||||
http.HandleFunc("/node/create", nodeCreate)
|
||||
http.HandleFunc("/node/update", nodeUpdate)
|
||||
http.HandleFunc("/node/rename", nodeRename)
|
||||
http.HandleFunc("/node/delete", nodeDelete)
|
||||
http.HandleFunc("/node/upload", nodeUpload)
|
||||
http.HandleFunc("/node/download", nodeDownload)
|
||||
http.HandleFunc("/node/search", nodeSearch)
|
||||
http.HandleFunc("/key/retrieve", keyRetrieve)
|
||||
http.HandleFunc("/key/create", keyCreate)
|
||||
http.HandleFunc("/key/counter", keyCounter)
|
||||
http.HandleFunc("/ws", websocketHandler)
|
||||
http.HandleFunc("/", staticHandler)
|
||||
|
||||
listen := fmt.Sprintf("%s:%d", LISTEN_HOST, flagPort)
|
||||
logger.Info("webserver", "listen", listen, "domains", config.Websocket.Domains)
|
||||
http.ListenAndServe(listen, nil)
|
||||
} // }}}
|
||||
|
||||
func cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
|
|
@ -137,8 +127,108 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{
|
|||
return
|
||||
}
|
||||
} // }}}
|
||||
func staticHandler(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
data := struct {
|
||||
VERSION string
|
||||
}{
|
||||
VERSION: VERSION,
|
||||
}
|
||||
|
||||
// URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version.
|
||||
// To get rid of problems with cached content in browser on a new version release,
|
||||
// while also not disabling cache altogether.
|
||||
logger.Debug("webserver", "request", r.URL.Path)
|
||||
if r.URL.Path == "/favicon.ico" {
|
||||
static.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
|
||||
if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
|
||||
r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
|
||||
static.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Everything else is run through the template system.
|
||||
// For now to get VERSION into files to fix caching.
|
||||
logger.Debug("webserver", "template", r.URL.Path)
|
||||
tmpl, err := newTemplate(r.URL.Path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err = tmpl.Execute(w, data); err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
}
|
||||
} // }}}
|
||||
|
||||
func sessionCreate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/session/create")
|
||||
session, err := CreateSession()
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
responseData(w, map[string]interface{}{
|
||||
"OK": true,
|
||||
"Session": session,
|
||||
})
|
||||
} // }}}
|
||||
func sessionRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/session/retrieve")
|
||||
var err error
|
||||
var found bool
|
||||
var session Session
|
||||
|
||||
if session, found, err = ValidateSession(r, false); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
responseData(w, map[string]interface{}{
|
||||
"OK": true,
|
||||
"Valid": found,
|
||||
"Session": session,
|
||||
})
|
||||
} // }}}
|
||||
func sessionAuthenticate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/session/authenticate")
|
||||
var err error
|
||||
var session Session
|
||||
var authenticated bool
|
||||
|
||||
// Validate session
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Username string
|
||||
Password string
|
||||
}{}
|
||||
if err = parseRequest(r, &req); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if authenticated, err = session.Authenticate(req.Username, req.Password); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
responseData(w, map[string]interface{}{
|
||||
"OK": true,
|
||||
"Authenticated": authenticated,
|
||||
"Session": session,
|
||||
})
|
||||
} // }}}
|
||||
|
||||
/*
|
||||
func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
var err error
|
||||
var ok bool
|
||||
|
|
@ -169,11 +259,15 @@ func userPassword(w http.ResponseWriter, r *http.Request) { // {{{
|
|||
"CurrentPasswordOK": ok,
|
||||
})
|
||||
} // }}}
|
||||
*/
|
||||
|
||||
func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
logger.Info("webserver", "request", "/node/tree")
|
||||
func nodeTree(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct{ StartNodeID int }{}
|
||||
if err = parseRequest(r, &req); err != nil {
|
||||
|
|
@ -181,7 +275,7 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
|||
return
|
||||
}
|
||||
|
||||
nodes, err := NodeTree(sess.UserID, req.StartNodeID)
|
||||
nodes, err := session.NodeTree(req.StartNodeID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -192,9 +286,15 @@ func nodeTree(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
|||
"Nodes": nodes,
|
||||
})
|
||||
} // }}}
|
||||
func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/retrieve")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct{ ID int }{}
|
||||
if err = parseRequest(r, &req); err != nil {
|
||||
|
|
@ -202,7 +302,7 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
|||
return
|
||||
}
|
||||
|
||||
node, err := RetrieveNode(sess.UserID, req.ID)
|
||||
node, err := session.Node(req.ID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -213,9 +313,15 @@ func nodeRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
|||
"Node": node,
|
||||
})
|
||||
} // }}}
|
||||
func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeCreate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/create")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Name string
|
||||
|
|
@ -226,7 +332,7 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
return
|
||||
}
|
||||
|
||||
node, err := CreateNode(sess.UserID, req.ParentID, req.Name)
|
||||
node, err := session.CreateNode(req.ParentID, req.Name)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -237,9 +343,15 @@ func nodeCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
"Node": node,
|
||||
})
|
||||
} // }}}
|
||||
func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeUpdate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/update")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
|
|
@ -251,7 +363,7 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
return
|
||||
}
|
||||
|
||||
err = UpdateNode(sess.UserID, req.NodeID, req.Content, req.CryptoKeyID)
|
||||
err = session.UpdateNode(req.NodeID, req.Content, req.CryptoKeyID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -261,10 +373,17 @@ func nodeUpdate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
"OK": true,
|
||||
})
|
||||
} // }}}
|
||||
func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
var err error
|
||||
func nodeRename(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/rename")
|
||||
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
Name string
|
||||
|
|
@ -274,7 +393,7 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
return
|
||||
}
|
||||
|
||||
err = RenameNode(sess.UserID, req.NodeID, req.Name)
|
||||
err = session.RenameNode(req.NodeID, req.Name)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -284,10 +403,17 @@ func nodeRename(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
"OK": true,
|
||||
})
|
||||
} // }}}
|
||||
func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
var err error
|
||||
func nodeDelete(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/delete")
|
||||
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
}{}
|
||||
|
|
@ -296,7 +422,7 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
return
|
||||
}
|
||||
|
||||
err = DeleteNode(sess.UserID, req.NodeID)
|
||||
err = session.DeleteNode(req.NodeID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -306,9 +432,15 @@ func nodeDelete(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
"OK": true,
|
||||
})
|
||||
} // }}}
|
||||
func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeUpload(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/upload")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse our multipart form, 10 << 20 specifies a maximum
|
||||
// upload of 10 MB files.
|
||||
|
|
@ -347,7 +479,7 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
MIME: handler.Header.Get("Content-Type"),
|
||||
MD5: md5sum,
|
||||
}
|
||||
if err = AddFile(sess.UserID, &nodeFile); err != nil {
|
||||
if err = session.AddFile(&nodeFile); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
|
@ -384,11 +516,17 @@ func nodeUpload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
"File": nodeFile,
|
||||
})
|
||||
} // }}}
|
||||
func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeDownload(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/download")
|
||||
var err error
|
||||
var session Session
|
||||
var files []File
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
FileID int
|
||||
|
|
@ -398,7 +536,7 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
|||
return
|
||||
}
|
||||
|
||||
files, err = Files(sess.UserID, req.NodeID, req.FileID)
|
||||
files, err = session.Files(req.NodeID, req.FileID)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -445,11 +583,17 @@ func nodeDownload(w http.ResponseWriter, r *http.Request, sess *session.T) { //
|
|||
}
|
||||
|
||||
} // }}}
|
||||
func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeFiles(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/files")
|
||||
var err error
|
||||
var session Session
|
||||
var files []File
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
NodeID int
|
||||
}{}
|
||||
|
|
@ -458,7 +602,7 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
|||
return
|
||||
}
|
||||
|
||||
files, err = Files(sess.UserID, req.NodeID, 0)
|
||||
files, err = session.Files(req.NodeID, 0)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -469,11 +613,17 @@ func nodeFiles(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
|||
"Files": files,
|
||||
})
|
||||
} // }}}
|
||||
func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func nodeSearch(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/node/search")
|
||||
var err error
|
||||
var session Session
|
||||
var nodes []Node
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Search string
|
||||
}{}
|
||||
|
|
@ -482,7 +632,7 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
return
|
||||
}
|
||||
|
||||
nodes, err = SearchNodes(sess.UserID, req.Search)
|
||||
nodes, err = session.SearchNodes(req.Search)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -494,11 +644,17 @@ func nodeSearch(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
})
|
||||
} // }}}
|
||||
|
||||
func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func keyRetrieve(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/key/retrieve")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
keys, err := Keys(sess.UserID)
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := session.Keys()
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -509,9 +665,15 @@ func keyRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) { // {
|
|||
"Keys": keys,
|
||||
})
|
||||
} // }}}
|
||||
func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func keyCreate(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/key/create")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Description string
|
||||
|
|
@ -522,7 +684,7 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
|||
return
|
||||
}
|
||||
|
||||
key, err := KeyCreate(sess.UserID, req.Description, req.Key)
|
||||
key, err := session.KeyCreate(req.Description, req.Key)
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -533,11 +695,17 @@ func keyCreate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
|||
"Key": key,
|
||||
})
|
||||
} // }}}
|
||||
func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
|
||||
func keyCounter(w http.ResponseWriter, r *http.Request) { // {{{
|
||||
logger.Info("webserver", "request", "/key/counter")
|
||||
var err error
|
||||
var session Session
|
||||
|
||||
counter, err := KeyCounter()
|
||||
if session, _, err = ValidateSession(r, true); err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
counter, err := session.KeyCounter()
|
||||
if err != nil {
|
||||
responseError(w, err)
|
||||
return
|
||||
|
|
@ -550,4 +718,20 @@ func keyCounter(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{
|
|||
})
|
||||
} // }}}
|
||||
|
||||
func newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{
|
||||
// Append index.html if needed for further reading of the file
|
||||
p := requestPath
|
||||
if p[len(p)-1] == '/' {
|
||||
p += "index.html"
|
||||
}
|
||||
p = config.Application.Directories.Static + p
|
||||
|
||||
base := path.Base(p)
|
||||
if tmpl, err = template.New(base).ParseFiles(p); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
} // }}}
|
||||
|
||||
// vim: foldmethod=marker
|
||||
|
|
|
|||
74
node.go
74
node.go
|
|
@ -25,9 +25,9 @@ type Node struct {
|
|||
ContentEncrypted string `db:"content_encrypted" json:"-"`
|
||||
}
|
||||
|
||||
func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
||||
func (session Session) NodeTree(startNodeID int) (nodes []Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
WITH RECURSIVE nodetree AS (
|
||||
SELECT
|
||||
*,
|
||||
|
|
@ -62,7 +62,7 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
|||
ORDER BY
|
||||
path ASC
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
startNodeID,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -79,9 +79,6 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
|||
for rows.Next() {
|
||||
node := Node{}
|
||||
node.Complete = false
|
||||
node.Crumbs = []Node{}
|
||||
node.Children = []Node{}
|
||||
node.Files = []File{}
|
||||
if err = rows.StructScan(&node); err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -90,9 +87,9 @@ func NodeTree(userID, startNodeID int) (nodes []Node, err error) {// {{{
|
|||
|
||||
return
|
||||
}// }}}
|
||||
func RootNode(userID int) (node Node, err error) {// {{{
|
||||
func (session Session) RootNode() (node Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
|
|
@ -103,7 +100,7 @@ func RootNode(userID int) (node Node, err error) {// {{{
|
|||
user_id = $1 AND
|
||||
parent_id IS NULL
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -111,7 +108,7 @@ func RootNode(userID int) (node Node, err error) {// {{{
|
|||
defer rows.Close()
|
||||
|
||||
node.Name = "Start"
|
||||
node.UserID = userID
|
||||
node.UserID = session.UserID
|
||||
node.Complete = true
|
||||
node.Children = []Node{}
|
||||
node.Crumbs = []Node{}
|
||||
|
|
@ -132,13 +129,13 @@ func RootNode(userID int) (node Node, err error) {// {{{
|
|||
|
||||
return
|
||||
}// }}}
|
||||
func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
|
||||
func (session Session) Node(nodeID int) (node Node, err error) {// {{{
|
||||
if nodeID == 0 {
|
||||
return RootNode(userID)
|
||||
return session.RootNode()
|
||||
}
|
||||
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
WITH RECURSIVE recurse AS (
|
||||
SELECT
|
||||
id,
|
||||
|
|
@ -173,7 +170,7 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
|
|||
|
||||
SELECT * FROM recurse ORDER BY level ASC
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -220,14 +217,14 @@ func RetrieveNode(userID, nodeID int) (node Node, err error) {// {{{
|
|||
}
|
||||
}
|
||||
|
||||
node.Crumbs, err = NodeCrumbs(node.ID)
|
||||
node.Files, err = Files(userID, node.ID, 0)
|
||||
node.Crumbs, err = session.NodeCrumbs(node.ID)
|
||||
node.Files, err = session.Files(node.ID, 0)
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
|
||||
func (session Session) NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
WITH RECURSIVE nodes AS (
|
||||
SELECT
|
||||
id,
|
||||
|
|
@ -263,10 +260,10 @@ func NodeCrumbs(nodeID int) (nodes []Node, err error) {// {{{
|
|||
}
|
||||
return
|
||||
}// }}}
|
||||
func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{{
|
||||
func (session Session) CreateNode(parentID int, name string) (node Node, err error) {// {{{
|
||||
var rows *sqlx.Rows
|
||||
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
INSERT INTO node(user_id, parent_id, name)
|
||||
VALUES($1, NULLIF($2, 0)::integer, $3)
|
||||
RETURNING
|
||||
|
|
@ -276,7 +273,7 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{
|
|||
name,
|
||||
content
|
||||
`,
|
||||
userID,
|
||||
session.UserID,
|
||||
parentID,
|
||||
name,
|
||||
)
|
||||
|
|
@ -295,12 +292,12 @@ func CreateNode(userID, parentID int, name string) (node Node, err error) {// {{
|
|||
node.Complete = true
|
||||
}
|
||||
|
||||
node.Crumbs, err = NodeCrumbs(node.ID)
|
||||
node.Crumbs, err = session.NodeCrumbs(node.ID)
|
||||
return
|
||||
}// }}}
|
||||
func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error) {// {{{
|
||||
func (session Session) UpdateNode(nodeID int, content string, cryptoKeyID int) (err error) {// {{{
|
||||
if cryptoKeyID > 0 {
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
_, err = db.Exec(`
|
||||
UPDATE node
|
||||
SET
|
||||
content = '',
|
||||
|
|
@ -316,10 +313,10 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error)
|
|||
content,
|
||||
cryptoKeyID,
|
||||
nodeID,
|
||||
userID,
|
||||
session.UserID,
|
||||
)
|
||||
} else {
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
_, err = db.Exec(`
|
||||
UPDATE node
|
||||
SET
|
||||
content = $1,
|
||||
|
|
@ -335,24 +332,24 @@ func UpdateNode(userID, nodeID int, content string, cryptoKeyID int) (err error)
|
|||
content,
|
||||
cryptoKeyID,
|
||||
nodeID,
|
||||
userID,
|
||||
session.UserID,
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func RenameNode(userID, nodeID int, name string) (err error) {// {{{
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
func (session Session) RenameNode(nodeID int, name string) (err error) {// {{{
|
||||
_, err = db.Exec(`
|
||||
UPDATE node SET name = $1 WHERE user_id = $2 AND id = $3
|
||||
`,
|
||||
name,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
)
|
||||
return
|
||||
}// }}}
|
||||
func DeleteNode(userID, nodeID int) (err error) {// {{{
|
||||
_, err = service.Db.Conn.Exec(`
|
||||
func (session Session) DeleteNode(nodeID int) (err error) {// {{{
|
||||
_, err = db.Exec(`
|
||||
WITH RECURSIVE nodetree AS (
|
||||
SELECT
|
||||
id, parent_id
|
||||
|
|
@ -371,15 +368,15 @@ func DeleteNode(userID, nodeID int) (err error) {// {{{
|
|||
DELETE FROM node WHERE id IN (
|
||||
SELECT id FROM nodetree
|
||||
)`,
|
||||
userID,
|
||||
session.UserID,
|
||||
nodeID,
|
||||
)
|
||||
return
|
||||
}// }}}
|
||||
func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
|
||||
func (session Session) SearchNodes(search string) (nodes []Node, err error) {// {{{
|
||||
nodes = []Node{}
|
||||
var rows *sqlx.Rows
|
||||
rows, err = service.Db.Conn.Queryx(`
|
||||
rows, err = db.Queryx(`
|
||||
SELECT
|
||||
id,
|
||||
user_id,
|
||||
|
|
@ -388,15 +385,14 @@ func SearchNodes(userID int, search string) (nodes []Node, err error) {// {{{
|
|||
updated
|
||||
FROM node
|
||||
WHERE
|
||||
user_id = $1 AND
|
||||
crypto_key_id IS NULL AND
|
||||
(
|
||||
content ~* $2 OR
|
||||
name ~* $2
|
||||
content ~* $1 OR
|
||||
name ~* $1
|
||||
)
|
||||
ORDER BY
|
||||
updated DESC
|
||||
`, userID, search)
|
||||
`, search)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
|||
119
session.go
Normal file
119
session.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// Standard
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
UserID int
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
func CreateSession() (session Session, err error) {// {{{
|
||||
var rows *sql.Rows
|
||||
if rows, err = db.Query(`
|
||||
INSERT INTO public.session(uuid)
|
||||
VALUES(gen_random_uuid())
|
||||
RETURNING uuid, created`,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
rows.Scan(&session.UUID, &session.Created)
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
|
||||
func sessionUUID(r *http.Request) (string, error) {// {{{
|
||||
headers := r.Header["X-Session-Id"]
|
||||
if len(headers) > 0 {
|
||||
return headers[0], nil
|
||||
}
|
||||
return "", errors.New("Invalid session")
|
||||
}// }}}
|
||||
func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{
|
||||
var uuid string
|
||||
if uuid, err = sessionUUID(r); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
session.UUID = uuid
|
||||
if found, err = session.Retrieve(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if notFoundIsError && !found {
|
||||
err = errors.New("Invalid session")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
|
||||
func (session *Session) Retrieve() (found bool, err error) {// {{{
|
||||
var rows *sql.Rows
|
||||
if rows, err = db.Query(`
|
||||
SELECT
|
||||
uuid, user_id, created
|
||||
FROM public.session
|
||||
WHERE
|
||||
uuid = $1 AND
|
||||
created + $2::interval >= NOW()
|
||||
`,
|
||||
session.UUID,
|
||||
fmt.Sprintf("%d days", config.Session.DaysValid),
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
found = false
|
||||
if rows.Next() {
|
||||
found = true
|
||||
rows.Scan(&session.UUID, &session.UserID, &session.Created)
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{
|
||||
var rows *sql.Rows
|
||||
if rows, err = db.Query(`
|
||||
SELECT id
|
||||
FROM public.user
|
||||
WHERE
|
||||
username=$1 AND
|
||||
password=password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
|
||||
`,
|
||||
username,
|
||||
password,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
rows.Scan(&session.UserID)
|
||||
authenticated = session.UserID > 0
|
||||
}
|
||||
|
||||
if authenticated {
|
||||
_, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}// }}}
|
||||
|
||||
|
||||
// vim: foldmethod=marker
|
||||
34
sql/0013.sql
Normal file
34
sql/0013.sql
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
/* Required for the gen_random_bytes function */
|
||||
CREATE EXTENSION pgcrypto;
|
||||
|
||||
CREATE FUNCTION password_hash(salt_hex char(32), pass bytea)
|
||||
RETURNS char(96)
|
||||
LANGUAGE plpgsql
|
||||
AS
|
||||
$$
|
||||
BEGIN
|
||||
RETURN (
|
||||
SELECT
|
||||
salt_hex ||
|
||||
encode(
|
||||
sha256(
|
||||
decode(salt_hex, 'hex') || /* salt in binary */
|
||||
pass /* password */
|
||||
),
|
||||
'hex'
|
||||
)
|
||||
);
|
||||
END;
|
||||
$$;
|
||||
|
||||
/* Password has to be able to accommodate 96 characters instead of previous 64.
|
||||
* It can't be char(96), because then the password would be padded to 96 characters. */
|
||||
ALTER TABLE public."user" ALTER COLUMN "password" TYPE varchar(96) USING "password"::varchar;
|
||||
|
||||
/* Update all users with salted and hashed passwords */
|
||||
UPDATE public.user
|
||||
SET password = password_hash( encode(gen_random_bytes(16),'hex'), password::bytea);
|
||||
|
||||
/* After the password hashing, all passwords are now hex encoded 32 characters salt and 64 characters hash,
|
||||
* and the varchar type is not longer necessary. */
|
||||
ALTER TABLE public."user" ALTER COLUMN "password" TYPE char(96) USING "password"::varchar;
|
||||
|
|
@ -13,8 +13,8 @@ class App extends Component {
|
|||
this.websocket = null
|
||||
this.websocket_int_ping = null
|
||||
this.websocket_int_reconnect = null
|
||||
//this.wsConnect() // XXX
|
||||
//this.wsLoop() // XXX
|
||||
this.wsConnect()
|
||||
this.wsLoop()
|
||||
|
||||
this.session = new Session(this)
|
||||
this.session.initialize()
|
||||
|
|
@ -114,7 +114,7 @@ class App extends Component {
|
|||
}
|
||||
|
||||
if(this.session.UUID !== '')
|
||||
headers['X-Session-ID'] = this.session.UUID
|
||||
headers['X-Session-Id'] = this.session.UUID
|
||||
|
||||
fetch(url, {
|
||||
method: 'POST',
|
||||
|
|
@ -201,24 +201,9 @@ class Tree extends Component {
|
|||
this.selectedTreeNode = null
|
||||
|
||||
this.props.app.tree = this
|
||||
this.retrieve()
|
||||
}//}}}
|
||||
render({ app }) {//{{{
|
||||
let renderedTreeTrunk = this.treeTrunk.map(node=>{
|
||||
this.treeNodeComponents[node.ID] = createRef()
|
||||
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
|
||||
})
|
||||
return html`<div id="tree">${renderedTreeTrunk}</div>`
|
||||
}//}}}
|
||||
|
||||
retrieve(callback = null) {//{{{
|
||||
this.props.app.request('/node/tree', { StartNodeID: 0 })
|
||||
.then(res=>{
|
||||
this.treeNodes = {}
|
||||
this.treeNodeComponents = {}
|
||||
this.treeTrunk = []
|
||||
this.selectedTreeNode = null
|
||||
|
||||
// A tree of nodes is built. This requires the list of nodes
|
||||
// returned from the server to be sorted in such a way that
|
||||
// a parent node always appears before a child node.
|
||||
|
|
@ -250,12 +235,17 @@ class Tree extends Component {
|
|||
this.crumbsUpdateNodes()
|
||||
this.forceUpdate()
|
||||
|
||||
if(callback)
|
||||
callback()
|
||||
|
||||
})
|
||||
.catch(this.responseError)
|
||||
}//}}}
|
||||
render({ app }) {//{{{
|
||||
let renderedTreeTrunk = this.treeTrunk.map(node=>{
|
||||
this.treeNodeComponents[node.ID] = createRef()
|
||||
return html`<${TreeNode} key=${"treenode_"+node.ID} tree=${this} node=${node} ref=${this.treeNodeComponents[node.ID]} selected=${node.ID == app.startNode.ID} />`
|
||||
})
|
||||
return html`<div id="tree">${renderedTreeTrunk}</div>`
|
||||
}//}}}
|
||||
|
||||
setSelected(node) {//{{{
|
||||
if(this.selectedTreeNode)
|
||||
this.selectedTreeNode.selected.value = false
|
||||
|
|
|
|||
|
|
@ -215,14 +215,7 @@ export class NodeUI extends Component {
|
|||
let name = prompt("Name")
|
||||
if(!name)
|
||||
return
|
||||
this.node.value.create(name, nodeID=>{
|
||||
console.log('before', this.props.app.startNode)
|
||||
this.props.app.startNode = new Node(this.props.app, nodeID)
|
||||
console.log('after', this.props.app.startNode)
|
||||
this.props.app.tree.retrieve(()=>{
|
||||
this.goToNode(nodeID)
|
||||
})
|
||||
})
|
||||
this.node.value.create(name, nodeID=>this.goToNode(nodeID))
|
||||
}//}}}
|
||||
saveNode() {//{{{
|
||||
let nodeContent = this.nodeContent.current
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
export class Session {
|
||||
constructor(app) {//{{{
|
||||
constructor(app) {
|
||||
this.app = app
|
||||
this.UUID = ''
|
||||
this.initialized = false
|
||||
this.UserID = 0
|
||||
}//}}}
|
||||
}
|
||||
|
||||
initialize() {//{{{
|
||||
// Retrieving the stored session UUID, if any.
|
||||
// If one found, validate with server.
|
||||
|
||||
|
||||
// If the browser doesn't know anything about a session,
|
||||
// a call to /session/create is necessary to retrieve a session UUID.
|
||||
let uuid = window.localStorage.getItem("session.UUID")
|
||||
if (uuid === null) {
|
||||
let uuid= window.localStorage.getItem("session.UUID")
|
||||
if(uuid === null) {
|
||||
this.create()
|
||||
return
|
||||
}
|
||||
|
|
@ -24,47 +25,47 @@ export class Session {
|
|||
// A call to /session/retrieve with a session UUID validates that the
|
||||
// session is still valid and returns all session information.
|
||||
this.UUID = uuid
|
||||
this.app.request('/_session/retrieve', {})
|
||||
.then(res => {
|
||||
if (res.Error === undefined) {
|
||||
// Session exists on server.
|
||||
// Not necessarily authenticated.
|
||||
this.UserID = res.Session.UserID // could be 0
|
||||
this.initialized = true
|
||||
this.app.forceUpdate()
|
||||
} else {
|
||||
// Session has probably expired. A new is required.
|
||||
this.create()
|
||||
}
|
||||
})
|
||||
.catch(this.app.responseError)
|
||||
}//}}}
|
||||
create() {//{{{
|
||||
this.app.request('/_session/new', {})
|
||||
.then(res => {
|
||||
this.UUID = res.Session.UUID
|
||||
window.localStorage.setItem('session.UUID', this.Session.UUID)
|
||||
this.app.request('/session/retrieve', {})
|
||||
.then(res=>{
|
||||
if(res.Valid) {
|
||||
// Session exists on server.
|
||||
// Not necessarily authenticated.
|
||||
this.UserID = res.Session.UserID // could be 0
|
||||
this.initialized = true
|
||||
this.app.forceUpdate()
|
||||
})
|
||||
.catch(this.responseError)
|
||||
} else {
|
||||
// Session has probably expired. A new is required.
|
||||
this.create()
|
||||
}
|
||||
})
|
||||
.catch(this.app.responseError)
|
||||
}//}}}
|
||||
create() {//{{{
|
||||
this.app.request('/session/create', {})
|
||||
.then(res=>{
|
||||
this.UUID = res.Session.UUID
|
||||
window.localStorage.setItem('session.UUID', this.UUID)
|
||||
this.initialized = true
|
||||
this.app.forceUpdate()
|
||||
})
|
||||
.catch(this.responseError)
|
||||
}//}}}
|
||||
authenticate(username, password) {//{{{
|
||||
this.app.login.current.authentication_failed.value = false
|
||||
|
||||
this.app.request('/_session/authenticate', {
|
||||
this.app.request('/session/authenticate', {
|
||||
username,
|
||||
password,
|
||||
})
|
||||
.then(res => {
|
||||
if (res.Authenticated) {
|
||||
this.UserID = res.UserID
|
||||
this.app.forceUpdate()
|
||||
} else {
|
||||
this.app.login.current.authentication_failed.value = true
|
||||
}
|
||||
})
|
||||
.catch(this.app.responseError)
|
||||
.then(res=>{
|
||||
if(res.Authenticated) {
|
||||
this.UserID = res.Session.UserID
|
||||
this.app.forceUpdate()
|
||||
} else {
|
||||
this.app.login.current.authentication_failed.value = true
|
||||
}
|
||||
})
|
||||
.catch(this.app.responseError)
|
||||
}//}}}
|
||||
authenticated() {//{{{
|
||||
return this.UserID != 0
|
||||
|
|
|
|||
47
user.go
47
user.go
|
|
@ -1,6 +1,11 @@
|
|||
package main
|
||||
|
||||
/*
|
||||
import (
|
||||
// Standard
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err error) {
|
||||
var result sql.Result
|
||||
var rowsAffected int64
|
||||
|
|
@ -9,10 +14,10 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
|
|||
UPDATE public.user
|
||||
SET
|
||||
password = password_hash(
|
||||
/ salt in hex /
|
||||
/* salt in hex */
|
||||
ENCODE(gen_random_bytes(16), 'hex'),
|
||||
|
||||
/ password /
|
||||
/* password */
|
||||
$1::bytea
|
||||
)
|
||||
WHERE
|
||||
|
|
@ -31,4 +36,38 @@ func (session Session) UpdatePassword(currPass, newPass string) (ok bool, err er
|
|||
|
||||
return rowsAffected > 0, nil
|
||||
}
|
||||
*/
|
||||
|
||||
func createUser() (err error) {
|
||||
var username, password string
|
||||
fmt.Printf("Username: ")
|
||||
fmt.Scanln(&username)
|
||||
fmt.Printf("Password: ")
|
||||
fmt.Scanln(&password)
|
||||
|
||||
err = CreateDbUser(username, password)
|
||||
return
|
||||
}
|
||||
|
||||
func CreateDbUser(username string, password string) (err error) {
|
||||
var result sql.Result
|
||||
|
||||
result, err = db.Exec(`
|
||||
INSERT INTO public.user(username, password)
|
||||
VALUES(
|
||||
$1,
|
||||
password_hash(
|
||||
/* salt in hex */
|
||||
ENCODE(gen_random_bytes(16), 'hex'),
|
||||
|
||||
/* password */
|
||||
$2::bytea
|
||||
)
|
||||
)
|
||||
`,
|
||||
username,
|
||||
password,
|
||||
)
|
||||
|
||||
_, err = result.RowsAffected()
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue