120 lines
2.1 KiB
Go
120 lines
2.1 KiB
Go
package main
|
|
|
|
import (
|
|
// Standard
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
// Session mirrors a row of public.session. CreateSession fills UUID and
// Created for a new row; Retrieve loads UUID, UserID and Created for an
// existing one; Authenticate sets UserID.
type Session struct {
	// UUID identifies the session; CreateSession obtains it from gen_random_uuid().
	UUID string
	// UserID is the public.user id bound to this session; 0 when none has been set/loaded.
	UserID int
	// Created is the creation timestamp stored in public.session.created.
	Created time.Time
}
|
|
|
|
func CreateSession() (session Session, err error) {// {{{
|
|
var rows *sql.Rows
|
|
if rows, err = db.Query(`
|
|
INSERT INTO public.session(uuid)
|
|
VALUES(gen_random_uuid())
|
|
RETURNING uuid, created`,
|
|
); err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
rows.Scan(&session.UUID, &session.Created)
|
|
}
|
|
|
|
return
|
|
}// }}}
|
|
|
|
// sessionUUID extracts the session identifier from the request's
// X-Session-Id header.
//
// It returns an "Invalid session" error when the header is absent or its
// first value is empty, so an empty string is never handed on as a session id.
func sessionUUID(r *http.Request) (string, error) { // {{{
	// Header.Get canonicalizes the key (the same key the map lookup used) and
	// returns the first value, or "" when the header is missing.
	if id := r.Header.Get("X-Session-Id"); id != "" {
		return id, nil
	}
	return "", errors.New("Invalid session")
} // }}}
|
|
func ValidateSession(r *http.Request, notFoundIsError bool) (session Session, found bool, err error) {// {{{
|
|
var uuid string
|
|
if uuid, err = sessionUUID(r); err != nil {
|
|
return
|
|
}
|
|
|
|
session.UUID = uuid
|
|
if found, err = session.Retrieve(); err != nil {
|
|
return
|
|
}
|
|
|
|
if notFoundIsError && !found {
|
|
err = errors.New("Invalid session")
|
|
return
|
|
}
|
|
|
|
return
|
|
}// }}}
|
|
|
|
func (session *Session) Retrieve() (found bool, err error) {// {{{
|
|
var rows *sql.Rows
|
|
if rows, err = db.Query(`
|
|
SELECT
|
|
uuid, user_id, created
|
|
FROM public.session
|
|
WHERE
|
|
uuid = $1 AND
|
|
created + $2::interval >= NOW()
|
|
`,
|
|
session.UUID,
|
|
fmt.Sprintf("%d days", config.Session.DaysValid),
|
|
); err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
found = false
|
|
if rows.Next() {
|
|
found = true
|
|
rows.Scan(&session.UUID, &session.UserID, &session.Created)
|
|
}
|
|
|
|
return
|
|
}// }}}
|
|
func (session *Session) Authenticate(username, password string) (authenticated bool, err error) {// {{{
|
|
var rows *sql.Rows
|
|
if rows, err = db.Query(`
|
|
SELECT id
|
|
FROM public.user
|
|
WHERE
|
|
username=$1 AND
|
|
password=$2
|
|
`,
|
|
username,
|
|
password,
|
|
); err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
rows.Scan(&session.UserID)
|
|
authenticated = session.UserID > 0
|
|
}
|
|
|
|
if authenticated {
|
|
_, err = db.Exec("UPDATE public.session SET user_id=$1 WHERE uuid=$2", session.UserID, session.UUID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
return
|
|
}// }}}
|
|
|
|
|
|
// vim: foldmethod=marker
|