177 lines
4.5 KiB
Go
177 lines
4.5 KiB
Go
package main
|
|
|
|
import (
	// Standard
	"encoding/json"
	"flag"
	"fmt"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	// External
	"github.com/google/uuid"
)
|
|
|
|
// VERSION is the version string of this program.
const VERSION = "v1"

// Command-line settings. They are registered (with defaults taken from the
// INVITE_* environment variables) and parsed in init().
var (
	// flagVersion is set by -version.
	flagVersion bool

	// Database connection settings (-host, -port, -username, -password, -database).
	flagHost     string
	flagPort     int
	flagUsername string
	flagPassword string
	flagDatabase string

	// HTTP service settings (-listen, -seq, -domain).
	flagListenPort       int
	flagSequenceFilename string
	flagDomain           string
)
|
|
|
|
func init() {
|
|
var err error
|
|
|
|
port := 5432
|
|
if os.Getenv("INVITE_DBPORT") != "" {
|
|
port, err = strconv.Atoi(os.Getenv("INVITE_DBPORT"))
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
listen := 9876
|
|
if os.Getenv("INVITE_LISTEN") != "" {
|
|
listen, err = strconv.Atoi(os.Getenv("INVITE_LISTEN"))
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
flag.BoolVar(&flagVersion, "version", false, "Display version and exit")
|
|
flag.StringVar(&flagHost, "host", os.Getenv("INVITE_DBHOST"), "Database host")
|
|
flag.StringVar(&flagUsername, "username", os.Getenv("INVITE_DBUSERNAME"), "Database username")
|
|
flag.StringVar(&flagPassword, "password", os.Getenv("INVITE_DBPASSWORD"), "Database password")
|
|
flag.StringVar(&flagDatabase, "database", os.Getenv("INVITE_DBNAME"), "Database name")
|
|
flag.IntVar(&flagPort, "port", port, "Database port")
|
|
flag.IntVar(&flagListenPort, "listen", listen, "Web server listen port")
|
|
flag.StringVar(&flagSequenceFilename, "seq", os.Getenv("INVITE_SEQ"), "Sequence filename")
|
|
flag.StringVar(&flagDomain, "domain", os.Getenv("INVITE_DOMAIN"), "Domain FQDN")
|
|
flag.Parse()
|
|
}
|
|
|
|
func main() {
|
|
err := initDB(flagHost, flagPort, flagDatabase, flagUsername, flagPassword)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
listenOn := fmt.Sprintf("[::]:%d", flagListenPort)
|
|
http.HandleFunc("/invite", createInvite)
|
|
fmt.Printf("Listen on %s\n", listenOn)
|
|
err = http.ListenAndServe(listenOn, nil)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func sequenceNext() (seq int, err error) {
|
|
var contents []byte
|
|
var oldSequence int
|
|
contents, err = os.ReadFile(flagSequenceFilename)
|
|
if err != nil && !os.IsNotExist(err) {
|
|
return
|
|
}
|
|
|
|
if err != nil && os.IsNotExist(err) {
|
|
oldSequence = 0
|
|
} else {
|
|
fixedContents := strings.TrimSpace(string(contents))
|
|
oldSequence, err = strconv.Atoi(fixedContents)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
seq = oldSequence + 1
|
|
err = os.WriteFile(flagSequenceFilename, []byte(strconv.Itoa(seq)), 0644)
|
|
return
|
|
}
|
|
|
|
func createInvite(w http.ResponseWriter, r *http.Request) {
|
|
seq, err := sequenceNext()
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
newUUID := uuid.NewString()
|
|
email := fmt.Sprintf("%08d@example.com", seq)
|
|
expire := time.Now().Add(time.Hour * 24 * 7)
|
|
|
|
// legacy_object
|
|
key := fmt.Sprintf("invitation:uid:2:invited:%s", email)
|
|
_, err = db.Exec(` INSERT INTO public.legacy_object(_key, "type", "expireAt") VALUES($1, 'string', null)`, key)
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
key = fmt.Sprintf("invitation:invited:%s", email)
|
|
_, err = db.Exec(` INSERT INTO public.legacy_object(_key, "type", "expireAt") VALUES($1, 'set', null)`, key)
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
key = fmt.Sprintf("invitation:token:%s", newUUID)
|
|
_, err = db.Exec(` INSERT INTO public.legacy_object(_key, "type", "expireAt") VALUES($1, 'hash', $2)`, key, expire)
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
// legacy_hash
|
|
key = fmt.Sprintf("invitation:token:%s", newUUID)
|
|
data := fmt.Sprintf(`{"email": "%s", "token": "%s", "inviter": 2, "groupsToJoin": "[]"}`, email, newUUID)
|
|
_, err = db.Exec(` INSERT INTO public.legacy_hash(_key, "data", "type") VALUES($1, $2, 'hash')`, key, data)
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
// legacy_set
|
|
key = fmt.Sprintf("invitation:uid:2")
|
|
_, err = db.Exec(` INSERT INTO public.legacy_set(_key, "member", "type") VALUES($1, $2, 'set')`, key, email)
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
key = fmt.Sprintf("invitation:invited:%s", email)
|
|
_, err = db.Exec(` INSERT INTO public.legacy_set(_key, "member", "type") VALUES($1, $2, 'set')`, key, newUUID)
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
key = fmt.Sprintf("invitation:uid:2:invited:%s", email)
|
|
_, err = db.Exec(` INSERT INTO public.legacy_string(_key, "data", "type") VALUES($1, $2, 'string')`, key, newUUID)
|
|
if err != nil {
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
|
|
j, _ := json.Marshal(struct {
|
|
Link string
|
|
}{
|
|
fmt.Sprintf("https://%s/register?token=%s", flagDomain, newUUID),
|
|
})
|
|
w.Write(j)
|
|
|
|
return
|
|
}
|