commit 85c71d237299a8d3683e4d685b951e8eedda800d Author: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu Aug 31 15:15:38 2023 +0530 initial commit diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..a60e17c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +# Dockerfile for cmd-api +# GOOS=linux GOARCH=amd64 +# Copy the source code into some directory +# Build it with GOOS and GOARCH +# Entrypoint would be this binary +# The configuration variables can be provided during docker run with "-e" option + +FROM go:... + +COPY . /app + +WORKDIR /app + +RUN "GOOS=linux GOARCH=amd64 go build -o ./cmd-api" + +ENTRYPOINT "./cmd-api" diff --git a/config.go b/config.go new file mode 100644 index 0000000..5a313cf --- /dev/null +++ b/config.go @@ -0,0 +1,52 @@ +package main + +import ( + "log" + "os" + "strconv" + "strings" + "time" +) + +// Config is a minimal config struct to hold the +// server port and a list of whitelisted commands +type Config struct { + Port int + WhitelistedCommmands []string + Timeout time.Duration +} + +func (config *Config) load() { + // port + portStr, ok := os.LookupEnv("PORT") + if !ok { + portStr = "8080" + } + port, err := strconv.Atoi(portStr) + if err != nil { + // simply print error and exit + log.Print("Error parsing port: ", err) + os.Exit(1) + } + config.Port = port + + // whitelisted commands + whitelistedCommandsStr, ok := os.LookupEnv("WHITELISTED_COMMANDS") + if !ok { + whitelistedCommandsStr = "ls,pwd,echo,cat,touch,rm,mkdir" + } + whitelistedCommandsStr = strings.TrimSpace(whitelistedCommandsStr) + config.WhitelistedCommmands = strings.Split(whitelistedCommandsStr, ",") + + // timeout + timeoutStr, ok := os.LookupEnv("TIMEOUT") + if !ok { + timeoutStr = "5s" + } + timeout, err := time.ParseDuration(timeoutStr) + if err != nil { + log.Print("Error parsing timeout: ", err) + os.Exit(1) + } + config.Timeout = timeout +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6e1e631 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.kaustubh.page/kaustubh/cmd-api + +go 1.19 diff --git a/main.go b/main.go new file mode 100644 index 0000000..496d032 --- /dev/null +++ b/main.go @@ -0,0 +1,6 @@ +package main + +func main() { + server := newServer() + server.start() +} diff --git a/runner.go b/runner.go new file mode 100644 index 0000000..80a8d95 --- /dev/null +++ b/runner.go @@ -0,0 +1,41 @@ +package main + +import ( + "context" + "errors" + "log" + "os/exec" + "strings" + "time" +) + +func runCommand(timeout time.Duration, command string, args []string) (*cmdResponse, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, command, args...) + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + + log.Printf("Error after running command: %v", err) + + // if the command exits with a non-zero exit code, + // we still want to return a valid cmdResponse + var exitError *exec.ExitError + if err != nil && !errors.As(err, &exitError) { + return nil, err + } + + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return nil, ctx.Err() + } + + response := cmdResponse{ + ExitCode: cmd.ProcessState.ExitCode(), + Stdout: stdout.String(), + Stderr: stderr.String(), + } + return &response, nil +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..a4b1911 --- /dev/null +++ b/server.go @@ -0,0 +1,134 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "os/exec" +) + +// Server struct is a way to pass global +// server state (i.e. config) to all the handlers +type Server struct { + config Config +} + +type cmdRequest struct { + Command string `json:"command"` + Args []string `json:"args"` + + // in addition, could have: + // - working directory + // - environment variables + // - user (?) + // - stdin +} + +type cmdResponse struct { + ExitCode int `json:"exit_code"` + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` +} + +// type for returning errors with some additional context +// when status codes are not enough +type errorMessageResponse struct { + Message string `json:"message"` +} + +func (s Server) isWhitelisted(command string) bool { + for _, whitelistedCommand := range s.config.WhitelistedCommmands { + if whitelistedCommand == command { + return true + } + } + return false +} + +func (s Server) start() { + http.HandleFunc("/api/cmd", s.handleAPI) + + log.Printf("Listening on port %d", s.config.Port) + log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", s.config.Port), nil)) +} + +// POST /api/cmd +func (s Server) handleAPI(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + // content-type must be application/json + if r.Header.Get("Content-Type") != "application/json" { + response := errorMessageResponse{Message: "Content-Type must be application/json"} + writeJSON(w, http.StatusUnsupportedMediaType, response) + return + } + + var body cmdRequest + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + response := errorMessageResponse{Message: "Request body is not in expected format"} + writeJSON(w, http.StatusUnprocessableEntity, response) + return + } + + // check if command is whitelisted + if !s.isWhitelisted(body.Command) { + log.Printf("Command '%s' is not whitelisted\n", body.Command) + response := errorMessageResponse{ + Message: fmt.Sprintf("Command '%s' is not whitelisted", body.Command), + } + writeJSON(w, http.StatusForbidden, response) + return + } + + // run the command + response, err := runCommand(s.config.Timeout, body.Command, body.Args) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + log.Printf("Command '%s' timed out", body.Command) + response := errorMessageResponse{ + Message: fmt.Sprintf("Command '%s' timed out", body.Command), + } + writeJSON(w, http.StatusRequestTimeout, response) + return + } else if errors.Is(err, exec.ErrNotFound) { + log.Printf("Command '%s' not found", body.Command) + response := errorMessageResponse{ + Message: fmt.Sprintf("Command '%s' not found", body.Command), + } + writeJSON(w, http.StatusNotFound, response) + return + } else { + log.Print("Error running command: ", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + } + + writeJSON(w, http.StatusOK, response) +} + +func writeJSON(w http.ResponseWriter, statusCode int, v any) { + body, err := json.Marshal(v) + if err != nil { + log.Print("Error marshalling response body: ", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + w.Write(body) +} + +func newServer() *Server { + config := Config{} + config.load() + return &Server{config: config} +}