diff --git a/script.go b/script.go index 13fe84f..e8f98ab 100644 --- a/script.go +++ b/script.go @@ -27,6 +27,7 @@ type Hook struct { Node Node Script Script SSH string + Env map[string]string } func GetScript(scriptID int) (script Script, err error) { // {{{ @@ -161,7 +162,7 @@ func SearchScripts(search string) (scripts []Script, err error) { // {{{ return } // }}} func HookScript(nodeID, scriptID int) (err error) { // {{{ - _, err = db.Exec(`INSERT INTO hook(node_id, script_id, ssh) VALUES($1, $2, '')`, nodeID, scriptID) + _, err = db.Exec(`INSERT INTO hook(node_id, script_id, ssh, env) VALUES($1, $2, '')`, nodeID, scriptID) return } // }}} @@ -173,6 +174,7 @@ func GetHook(hookID int) (hook Hook, err error) { // {{{ SELECT h.id, h.ssh, + h.env, (SELECT to_json(node) FROM node WHERE id = h.node_id) AS node, (SELECT to_json(script) FROM script WHERE id = h.script_id) AS script FROM hook h @@ -243,7 +245,8 @@ func ScheduleHook(hookID int) (err error) { // {{{ return } - _, err = db.Exec(`INSERT INTO execution(script_log_id, data, ssh) VALUES($1, $2, $3)`, scriptLogID, nodeData, hook.SSH) + j, _ := json.Marshal(hook.Env) + _, err = db.Exec(`INSERT INTO execution(script_log_id, data, ssh, env) VALUES($1, $2, $3, $4)`, scriptLogID, nodeData, hook.SSH, j) if err != nil { err = werr.Wrap(err) return diff --git a/script_scheduler.go b/script_scheduler.go index 96878a7..96fc913 100644 --- a/script_scheduler.go +++ b/script_scheduler.go @@ -7,12 +7,16 @@ import ( // Standard "bytes" "database/sql" + "encoding/json" "fmt" "os/exec" "strings" "time" ) +const ENV_NAME = 0 +const SCRIPT_NAME = 1 + type ScriptScheduler struct { EventQueue chan string } @@ -24,6 +28,7 @@ type ScriptExecution struct { Source []byte Data []byte SSH string + Env []byte OutputStdout sql.NullString `db:"output_stdout"` OutputStderr sql.NullString `db:"output_stderr"` ExitCode sql.NullInt16 @@ -70,30 +75,32 @@ func (self ScriptScheduler) HandleNextExecution() { // {{{ se.Update() logger.Info("script_scheduler", "op", "execute", "id", se.ID) - fname, err := se.GetScriptTempFilename() - if err != nil { - err = werr.Wrap(err) - logger.Error("script_execution", "op", "get_script_temp_filename", "id", se.ID, "error", err) - return - } - err = se.UploadScript(fname) + var fnames []string + fnames, err = se.UploadScript() if err != nil { err = werr.Wrap(err) logger.Error("script_execution", "op", "upload_script", "id", se.ID, "error", err) return } - err = se.RunScript(fname) + err = se.UploadEnv(fnames[ENV_NAME], fnames[SCRIPT_NAME]) + if err != nil { + err = werr.Wrap(err) + logger.Error("script_execution", "op", "upload_env", "id", se.ID, "error", err) + return + } + + err = se.RunScript(fnames[ENV_NAME]) if err != nil { err = werr.Wrap(err) logger.Error("script_execution", "op", "run_script", "id", se.ID, "error", err) return } - se.SSHCommand([]byte{}, false, fmt.Sprintf("rm %s", fname)) + se.SSHCommand([]byte{}, false, fmt.Sprintf("rm %s %s", fnames[ENV_NAME], fnames[SCRIPT_NAME])) - logger.Info("script_scheduler", "op", "handled", "script", fname) + logger.Info("script_scheduler", "op", "handled", "script", fnames[SCRIPT_NAME]) } // }}} func (self ScriptScheduler) GetNextExecution() (e ScriptExecution, err error) { // {{{ row := db.QueryRowx(` @@ -103,6 +110,7 @@ func (self ScriptScheduler) GetNextExecution() (e ScriptExecution, err error) { time_end, data, ssh, + env, output_stdout, output_stderr, exitcode, @@ -189,20 +197,49 @@ func (se *ScriptExecution) SSHCommand(stdin []byte, log bool, args ...string) (s return stdout.String(), nil } // }}} -func (se *ScriptExecution) GetScriptTempFilename() (fname string, err error) { // {{{ - fname, err = se.SSHCommand([]byte{}, true, "mktemp -t datagraph.XXXXXX") +func (se *ScriptExecution) UploadScript() (fnames []string, err error) { // {{{ + var filenames string + filenames, err = se.SSHCommand( + se.Source, + true, + `sh -c 'RUNENV=$(mktemp -t datagraph.XXXXXX) && SCRIPT=$(mktemp -t datagraph.XXXXXX) && touch $RUNENV $SCRIPT && chmod 700 $RUNENV $SCRIPT && cat >$SCRIPT && echo $RUNENV $SCRIPT'`, + ) + if err != nil { + err = werr.Wrap(err) + } + + fnames = strings.Split(strings.TrimSpace(filenames), " ") + + if len(fnames) != 2 { + err = werr.New("Invalid temp filename count: %d", len(fnames)) + return + } + + return fnames[:2], nil +} // }}} +func (se *ScriptExecution) UploadEnv(envFname, scriptFname string) (err error) { // {{{ + env := make(map[string]string) + err = json.Unmarshal(se.Env, &env) if err != nil { err = werr.Wrap(err) return } - fname = strings.TrimSpace(fname) - return -} // }}} -func (se *ScriptExecution) UploadScript(fname string) (err error) { // {{{ - _, err = se.SSHCommand(se.Source, true, fmt.Sprintf("sh -c 'touch %s && chmod 700 %s && cat >%s'", fname, fname, fname)) + + var script = "#!/bin/sh\n\n" + for key, val := range env { + script = script + fmt.Sprintf("export %s=\"%s\"\n", key, strings.ReplaceAll(val, `"`, `\"`)) + } + script = script + "\n" + scriptFname + "\n" + + _, err = se.SSHCommand( + []byte(script), + true, + fmt.Sprintf(`sh -c 'cat >%s'`, envFname), + ) if err != nil { err = werr.Wrap(err) } + return } // }}} func (se *ScriptExecution) RunScript(fname string) (err error) { // {{{ diff --git a/sql/0015.sql b/sql/0015.sql new file mode 100644 index 0000000..dedf69f --- /dev/null +++ b/sql/0015.sql @@ -0,0 +1 @@ +ALTER TABLE public.hook ADD env jsonb DEFAULT '{}' NOT NULL;