135 lines
3.3 KiB
Go
135 lines
3.3 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os/exec"
|
|
)
|
|
|
|
// Server struct is a way to pass global
|
|
// server state (i.e. config) to all the handlers
|
|
type Server struct {
|
|
config Config
|
|
}
|
|
|
|
type cmdRequest struct {
|
|
Command string `json:"command"`
|
|
Args []string `json:"args"`
|
|
|
|
// in addition, could have:
|
|
// - working directory
|
|
// - environment variables
|
|
// - user (?)
|
|
// - stdin
|
|
}
|
|
|
|
type cmdResponse struct {
|
|
ExitCode int `json:"exit_code"`
|
|
Stdout string `json:"stdout"`
|
|
Stderr string `json:"stderr"`
|
|
}
|
|
|
|
// type for returning errors with some additional context
|
|
// when status codes are not enough
|
|
type errorMessageResponse struct {
|
|
Message string `json:"message"`
|
|
}
|
|
|
|
func (s Server) isWhitelisted(command string) bool {
|
|
for _, whitelistedCommand := range s.config.WhitelistedCommmands {
|
|
if whitelistedCommand == command {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s Server) start() {
|
|
http.HandleFunc("/api/cmd", s.handleAPI)
|
|
|
|
log.Printf("Listening on port %d", s.config.Port)
|
|
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", s.config.Port), nil))
|
|
}
|
|
|
|
// POST /api/cmd
|
|
func (s Server) handleAPI(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// content-type must be application/json
|
|
if r.Header.Get("Content-Type") != "application/json" {
|
|
response := errorMessageResponse{Message: "Content-Type must be application/json"}
|
|
writeJSON(w, http.StatusUnsupportedMediaType, response)
|
|
return
|
|
}
|
|
|
|
var body cmdRequest
|
|
err := json.NewDecoder(r.Body).Decode(&body)
|
|
if err != nil {
|
|
response := errorMessageResponse{Message: "Request body is not in expected format"}
|
|
writeJSON(w, http.StatusUnprocessableEntity, response)
|
|
return
|
|
}
|
|
|
|
// check if command is whitelisted
|
|
if !s.isWhitelisted(body.Command) {
|
|
log.Printf("Command '%s' is not whitelisted\n", body.Command)
|
|
response := errorMessageResponse{
|
|
Message: fmt.Sprintf("Command '%s' is not whitelisted", body.Command),
|
|
}
|
|
writeJSON(w, http.StatusForbidden, response)
|
|
return
|
|
}
|
|
|
|
// run the command
|
|
response, err := runCommand(s.config.Timeout, body.Command, body.Args)
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
log.Printf("Command '%s' timed out", body.Command)
|
|
response := errorMessageResponse{
|
|
Message: fmt.Sprintf("Command '%s' timed out", body.Command),
|
|
}
|
|
writeJSON(w, http.StatusRequestTimeout, response)
|
|
return
|
|
} else if errors.Is(err, exec.ErrNotFound) {
|
|
log.Printf("Command '%s' not found", body.Command)
|
|
response := errorMessageResponse{
|
|
Message: fmt.Sprintf("Command '%s' not found", body.Command),
|
|
}
|
|
writeJSON(w, http.StatusNotFound, response)
|
|
return
|
|
} else {
|
|
log.Print("Error running command: ", err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, response)
|
|
}
|
|
|
|
func writeJSON(w http.ResponseWriter, statusCode int, v any) {
|
|
body, err := json.Marshal(v)
|
|
if err != nil {
|
|
log.Print("Error marshalling response body: ", err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(statusCode)
|
|
w.Write(body)
|
|
}
|
|
|
|
func newServer() *Server {
|
|
config := Config{}
|
|
config.load()
|
|
return &Server{config: config}
|
|
}
|