#ifndef _GNU_SOURCE
#define _GNU_SOURCE /* asprintf, setns, setresuid/setresgid, CLONE_NEW* */
#endif
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <signal.h>
#include <sys/syscall.h>
#include <string.h>
#include <limits.h>
#include <stdlib.h>
#include <sys/ptrace.h>
#include <wait.h>
#include <stdint.h>
#include <errno.h>
#include <sys/user.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <sched.h>
#include <sys/types.h>
#include <dirent.h>
#include <sys/prctl.h>

/* One record as delivered by the interceptor device: main() fills it with a
   single read() of sizeof(struct gei_user_msg) bytes, so the field layout
   must not be changed on this side alone. */
struct gei_user_msg {
  /* in this file only logged by main() and closed by finish_msg() */
  int gei_fd;
  /* handed to pidfd_send_signal() and setns(); closed by finish_msg() */
  int pid_fd;
  /* numeric pid, used to build /proc/<pid> and as the ptrace()/waitpid() target */
  pid_t pid;
};

/* Thin wrapper that issues the pidfd_send_signal system call directly
   through syscall(). The arguments are forwarded untouched and the result
   of syscall() is narrowed to int. */
static inline int sys_pidfd_send_signal(int pidfd, int sig, siginfo_t *info,
                                        unsigned int flags)
{
  long rc = syscall(__NR_pidfd_send_signal, pidfd, sig, info, flags);
  return (int)rc;
}
/* Release a message: close its pid descriptor first, then its gei
   descriptor. The struct itself is left untouched, so the caller must not
   call this twice for the same message. */
void finish_msg(struct gei_user_msg *msg) {
  const int pidfd = msg->pid_fd;
  const int geifd = msg->gei_fd;

  close(pidfd);
  close(geifd);
}

/* Open the /proc/<pid> directory of the process described by msg.
   Returns an open directory descriptor on success. On any failure the
   message is released with finish_msg() and -1 is returned. */
int get_procfd(struct gei_user_msg *msg) {
  char path[20];

  /* msg->pid is expressed in our pid namespace, so the proc we look at has
     to belong to that namespace as well (TODO: make our own proc mount) */
  int written = sprintf(path, "/proc/%d", msg->pid);
  if (written < 0) {
    finish_msg(msg);
    return -1;
  }

  int fd = open(path, O_DIRECTORY|O_RDONLY);

  /* Probe the pidfd with signal 0 only after the open succeeded: if the
     process is still alive at this point, the lookup can't have been racy. */
  if (fd >= 0 && sys_pidfd_send_signal(msg->pid_fd, 0, NULL, 0) >= 0) {
    return fd;
  }

  fprintf(stderr, "process died while working! %d\n", msg->pid);
  if (fd >= 0) {
    close(fd);
  }
  finish_msg(msg);
  return -1;
}

/* TODO: Configurable logic. For now, it's only the dockerize'd javas
   that we're worrying about. */
#define JAVA_NAME1 "/java"
#define NOTJAVA_PREFIX "/nix/store/"
/* Decide from an executable path whether the process counts as a java
   that should be injected.
     exe: the path bytes; need not be NUL-terminated (readlinkat() output)
     len: number of valid bytes in exe
   Returns 1 when the path ends in JAVA_NAME1 and does not start with
   NOTJAVA_PREFIX, 0 otherwise (including paths shorter than JAVA_NAME1,
   which previously made the suffix comparison read before the buffer). */
int check_path_is_java(char *exe, int len) {
  const size_t suffix_len = strlen(JAVA_NAME1);
  const size_t prefix_len = strlen(NOTJAVA_PREFIX);

  if (len < 0 || (size_t)len < suffix_len) { return 0; }
  const size_t n = (size_t)len;

  /* only look at bytes inside [exe, exe + n) */
  if (memcmp(exe + n - suffix_len, JAVA_NAME1, suffix_len) != 0) { return 0; }
  if (n >= prefix_len && memcmp(exe, NOTJAVA_PREFIX, prefix_len) == 0) {
    return 0;
  }
  return 1;
}
/* TODO: support more detailed checks (on things other than the
   executable name) */
/* Classify the process behind procfd (a descriptor for its /proc/<pid>
   directory) by resolving the "exe" link and passing the path to
   check_path_is_java().
   Returns the result of check_path_is_java(), or 1 when the link cannot be
   read or memory runs out (safe(ish) to just assume yes on err). */
int is_java(int procfd) {
  char stackbuf[PATH_MAX];
  ssize_t n = readlinkat(procfd, "exe", stackbuf, sizeof stackbuf);
  if (n < 0) { return 1; }
  if ((size_t)n < sizeof stackbuf) {
    return check_path_is_java(stackbuf, (int)n);
  }

  /* slowpath: the link filled the whole buffer, so it may have been
     truncated; retry with a doubling heap buffer until it fits */
  size_t bufsize = sizeof stackbuf;
  char *heapbuf = NULL;
  while ((size_t)n == bufsize) {
    free(heapbuf);
    heapbuf = NULL;
    /* check_path_is_java() takes an int length */
    if (bufsize > (size_t)INT_MAX / 2) { return 1; }
    bufsize *= 2;
    heapbuf = malloc(bufsize);
    if (!heapbuf) { return 1; }
    n = readlinkat(procfd, "exe", heapbuf, bufsize);
    if (n < 0) {
      free(heapbuf);
      return 1;
    }
  }
  int verdict = check_path_is_java(heapbuf, (int)n);
  free(heapbuf);
  return verdict;
}

/* Print msg together with the current errno text, send SIGKILL through
   pidfd_send_signal(), and terminate this program with status 1.
   NOTE: despite its name, `pid` is passed as the descriptor argument of
   pidfd_send_signal(), i.e. callers are expected to hand in a pidfd.
   Does not return. */
void ptrace_log_and_kill(char *msg, int pid) {
  perror(msg);
  sys_pidfd_send_signal(pid, SIGKILL, NULL, 0);
  exit(1);
}

/* Pairs an environment string with a child pid.
   NOTE(review): nothing in the visible part of this file references this
   struct -- confirm whether it is still needed. */
struct inject_env_info {
  char *env_str;
  pid_t childpid;
};
/* Inspect a single mountinfo line. If its filesystem type (the field right
   after the " - " separator) is "proc", return a malloc()ed copy of the
   mount point (the fifth space-separated field); otherwise, or when the
   line is malformed or memory runs out, return NULL. The mount point is
   copied verbatim, i.e. any octal escapes are left as they appear. */
static char *proc_mountpoint_from_line(const char *line) {
  const char *field = line;

  /* skip the four fields that precede the mount point */
  for (int skipped = 0; skipped < 4; skipped++) {
    field = strchr(field, ' ');
    if (!field) { return NULL; }
    field++;
  }
  const char *mp_end = strchr(field, ' ');
  if (!mp_end) { return NULL; }

  /* a variable number of optional fields may sit between the mount options
     and the separator, so search for the separator instead of counting */
  const char *sep = strstr(mp_end, " - ");
  if (!sep || strncmp(sep + 3, "proc ", 5) != 0) { return NULL; }

  size_t mp_len = (size_t)(mp_end - field);
  char *copy = malloc(mp_len + 1);
  if (!copy) { return NULL; }
  memcpy(copy, field, mp_len);
  copy[mp_len] = '\0';
  return copy;
}

/* Find where a proc filesystem is mounted according to the "mountinfo"
   file inside the directory procfd refers to.
     procfd: directory descriptor (a /proc/<pid> directory in this program)
     msg:    unused; kept so existing callers need no change
   Returns a malloc()ed, NUL-terminated mount point of the first proc entry
   (the caller owns and frees it), or NULL when the file cannot be opened
   or read or lists no proc mount. */
char *find_proc_mountpoint(int procfd, struct gei_user_msg *msg) {
  (void)msg;

  int mifd = openat(procfd, "mountinfo", O_RDONLY);
  if (mifd < 0) { return NULL; }
  FILE *mi = fdopen(mifd, "r");
  if (!mi) {
    close(mifd);
    return NULL;
  }

  char *line = NULL;
  size_t cap = 0;
  char *found = NULL;
  /* getline() grows the buffer as needed, so long lines are not a problem;
     it returns -1 on both EOF and error, which ends the search */
  while (!found && getline(&line, &cap, mi) >= 0) {
    found = proc_mountpoint_from_line(line);
  }

  free(line);
  fclose(mi); /* also closes mifd */
  return found;
}
/* Park the calling process for good: whenever pause() comes back, go
   straight back to waiting. Never returns. */
void sleep_forever(void) {
  for (;;) {
    pause();
  }
}
/* Fork a helper that joins the pid and mount namespaces of msg->pid_fd and
   reports its own pid back through a pipe. The helper never closes
   preload_fd; the caller uses the returned pid to build a
   "<procmt>/<pid>/fd/<preload_fd>" path.
     gei_fd:     interceptor device descriptor, closed in the helper
     procfd:     /proc/<pid> directory of the target; its owner decides the
                 uid/gid the helper drops to, and its "fd" subdirectory is
                 polled by the helper
     msg:        message of the target; finished inside the helper only
     preload_fd: descriptor number used in the path the helper looks for
     procmt:     proc mount point used in that path
   Returns the pid reported by the helper, or -1 on any failure (pipe/fork
   error, helper died before reporting, helper reported failure).
   NOTE(review): after the second fork both the intermediate process and
   its child continue into the polling loop; only the child reports its
   pid. Confirm that keeping the intermediate process around is intended. */
pid_t inject_path_helper(int gei_fd, int procfd, struct gei_user_msg *msg,
                         int preload_fd, char *procmt) {
  int pipefd[2];
  if (pipe(pipefd) < 0) { return -1; }
  pid_t cpid = fork();
  if (cpid < 0) {
    close(pipefd[0]);
    close(pipefd[1]);
    return -1;
  }

  if (cpid > 0) {
    /* parent: wait for the helper to report its pid */
    close(pipefd[1]);
    pid_t helper_pid = -1;
    ssize_t got;
    do {
      got = read(pipefd[0], &helper_pid, sizeof(helper_pid));
    } while (got < 0 && errno == EINTR);
    close(pipefd[0]);
    /* EOF or a short read means no helper ever reported a pid */
    if (got != (ssize_t)sizeof(helper_pid)) { return -1; }
    return helper_pid;
  }

  /* child */
  close(pipefd[0]);
  /* the pid namespace change only takes effect for processes forked from
     here on, hence the second fork below */
  if (setns(msg->pid_fd, CLONE_NEWPID|CLONE_NEWNS) < 0) {
    perror("setns");
    exit(1); /* closes the pipe; the parent then sees EOF and returns -1 */
  }
  finish_msg(msg);
  pid_t ccpid = fork();
  if (ccpid < 0) { exit(1); }

  struct stat st;
  int mypid = getpid();
  int have_owner = (fstat(procfd, &st) == 0);
  if (!have_owner) { mypid = -1; }
  /* procfd stays open: the polling loop below reads through it */
  close(0);
  close(1);
  close(2);
  close(gei_fd);
  if (have_owner) {
    /* take on the owner of the target's /proc directory; group first,
       while we are still allowed to change it */
    setresgid(st.st_gid, st.st_gid, st.st_gid); /* todo */
    setresuid(st.st_uid, st.st_uid, st.st_uid); /* todo */
  }
  prctl(PR_SET_DUMPABLE, 1);
  if (!ccpid) { write(pipefd[1], &mypid, sizeof(pid_t)); }
  close(pipefd[1]);
  if (mypid < 0) { exit(1); }

  /* from now on, just looking for opportunities to terminate
   * sleep forever if anything goes wrong. */
  char *want;
  if (asprintf(&want, "%s/%d/fd/%d", procmt, mypid, preload_fd) < 0) {
    sleep_forever();
  }
  size_t want_len = strlen(want);
  /* one spare byte so a link longer than `want` is told apart from an
     exact match */
  char *link = malloc(want_len + 1);
  if (!link) { sleep_forever(); }
  while (1) {
    sleep(1);
    int fddir_fd = openat(procfd, "fd", O_RDONLY);
    if (fddir_fd < 0) { sleep_forever(); }
    /* fdopendir() takes ownership of its descriptor, so give it a
       duplicate and keep fddir_fd for readlinkat() */
    int stream_fd = dup(fddir_fd);
    if (stream_fd < 0) { close(fddir_fd); sleep_forever(); }
    DIR *fddir = fdopendir(stream_fd);
    if (!fddir) { close(fddir_fd); close(stream_fd); sleep_forever(); }
    struct dirent *ent;
    while ((ent = readdir(fddir)) != NULL) {
      ssize_t n = readlinkat(fddir_fd, ent->d_name, link, want_len + 1);
      if (n < 0) { continue; }
      /* compare only the bytes readlinkat() actually wrote */
      if ((size_t)n == want_len && !memcmp(want, link, want_len)) { exit(0); }
    }
    closedir(fddir);
    close(fddir_fd);
  }
}
/* Build the environment entry to inject,
   "LD_PRELOAD=<proc mount>/<helper pid>/fd/<preload_fd>".
   Looks up the proc mount point through procfd and starts the helper
   process whose pid appears in the path.
   Returns a malloc()ed string owned by the caller, or NULL when the mount
   point cannot be determined, the helper cannot be started, or memory
   runs out. */
char *get_iei(int gei_fd, int procfd, struct gei_user_msg *msg,
              int preload_fd) {
  char *procmt = find_proc_mountpoint(procfd, msg);
  /* without a mount point there is no usable path; don't start a helper */
  if (!procmt) { return NULL; }

  char *ret = NULL;
  pid_t pid = inject_path_helper(gei_fd, procfd, msg, preload_fd, procmt);
  if (pid >= 0 &&
      asprintf(&ret, "LD_PRELOAD=%s/%d/fd/%d", procmt, pid, preload_fd) < 0) {
    ret = NULL; /* contents of ret are unspecified after a failed asprintf */
  }
  free(procmt);
  return ret;
}
/* Report a failure during stack rewriting, kill the tracee and terminate
   this program with status 1. The tracee is signalled by numeric pid
   because its pidfd has already been closed by the time this is used; it
   is still attached to us as a tracee at that point. Does not return. */
static void inject_abort(const char *what, pid_t pid) {
  perror(what);
  kill(pid, SIGKILL);
  exit(1);
}

/* TODO: support architectures other than amd64 */
/* Attach to the process described by msg with ptrace, wait until it
   reports a SIGSTOP, and splice one extra environment entry (the string
   from get_iei()) into the memory at its stack pointer.
   The message is always finished in here (the caller must not do it).
   If no inject string can be determined or the attach fails, the process
   is sent SIGKILL; failures once the stack is being rewritten kill the
   process and terminate this program.
   NOTE(review): the rewrite reads argc at rsp followed by NULL-terminated
   pointer arrays and auxv pairs, i.e. it assumes the process is stopped in
   its initial post-execve state -- confirm against the kernel side. */
void inject_java(int gei_fd, int procfd, struct gei_user_msg *msg,
                 int preload_fd) {
  char *to_inject = get_iei(gei_fd, procfd, msg, preload_fd);
  if (!to_inject) {
    fprintf(stderr, "inject string determination failed!\n");
    sys_pidfd_send_signal(msg->pid_fd, 9, NULL, 0);
    finish_msg(msg);
    return;
  }
  fprintf(stderr, "string to inject: '%s'\n", to_inject);

  if (ptrace(PTRACE_ATTACH, msg->pid, NULL, NULL)) {
    perror("ptrace failed!\n");
    sys_pidfd_send_signal(msg->pid_fd, 9, NULL, 0);
    finish_msg(msg);
    free(to_inject);
    return;
  }
  if (sys_pidfd_send_signal(msg->pid_fd, 0, NULL, 0)) {
    fprintf(stderr, "Process died before trace %d\n", msg->pid);
    finish_msg(msg);
    free(to_inject);
    return;
  }
  finish_msg(msg); /* let the process run until stopped */
  /* msg->pid_fd and msg->gei_fd are closed from here on; only msg->pid
     may still be used */
  int status;
  /* eat a SIGTRAP for the execve finishing */
  while (1) {
    pid_t waited;
    do {
      waited = waitpid(msg->pid, &status, 0);
    } while (waited < 0 && errno == EINTR);
    if (waited != msg->pid) {
      /* a hard error (e.g. ECHILD) would otherwise spin forever */
      perror("waitpid");
      free(to_inject);
      return;
    }
    if (WIFEXITED(status) || WIFSIGNALED(status)) {
      fprintf(stderr, "Process died before stopping %d\n", msg->pid);
      free(to_inject);
      return;
    }
    int sig = WSTOPSIG(status);
    if (WIFSTOPPED(status) && sig == SIGSTOP) { break; }
    if (WIFSTOPPED(status) && sig == SIGTRAP) { sig = 0; }
    if (ptrace(PTRACE_CONT, msg->pid, NULL, (void*)(uintptr_t)sig)) {
      /* gnu claims that this means it died */
      fprintf(stderr, "Process died before stopping (2) %d\n", msg->pid);
      free(to_inject);
      return;
    }
  }

  struct user_regs_struct regs;
  if (ptrace(PTRACE_GETREGS, msg->pid, NULL, &regs)) {
    inject_abort("PTRACE_GETREGS", msg->pid);
  }
  /* PTRACE_PEEKDATA returns the word itself, so errors are only visible
     through errno, which therefore has to be cleared before every call */
  errno = 0;
  long argc = ptrace(PTRACE_PEEKDATA, msg->pid, regs.rsp, 0);
  if (errno) {
    inject_abort("PTRACE_PEEKDATA argc", msg->pid);
  }

  /* the slot after the last argv pointer has to hold the terminator */
  errno = 0;
  long argvn = ptrace(PTRACE_PEEKDATA, msg->pid, regs.rsp + 8*(argc + 1), 0);
  if (errno || argvn) {
    inject_abort("PTRACE_PEEKDATA argvn", msg->pid);
  }

  /* room for one extra pointer plus the string with its terminator */
  size_t shift_offset = 8 + (strlen(to_inject) + 1);
  /* round to a multiple of two eightbytes for alignment reasons */
  shift_offset = (shift_offset & ~ 0xf) + 16 * !!(shift_offset & 0xf);
  size_t env_entry_address = 0;

  int n_zeros_found = 0;
  int in_aux = 0;
  int in_auxval = 1;
  size_t cur_addr = regs.rsp;
  regs.rsp = regs.rsp - shift_offset;
  /* Copy every word downwards by shift_offset until the third zero
     terminator. Zeros in the value half of an aux pair do not count. */
  while (n_zeros_found < 3) {
    errno = 0;
    long entry = ptrace(PTRACE_PEEKDATA, msg->pid, cur_addr, 0);
    if (errno) {
      inject_abort("PTRACE_PEEKDATA entry", msg->pid);
    }
    if (ptrace(PTRACE_POKEDATA, msg->pid, cur_addr - shift_offset, entry)) {
      inject_abort("PTRACE_POKEDATA entry", msg->pid);
    }
    if (!entry && !(in_aux && in_auxval)) {
      n_zeros_found++;
      if (n_zeros_found == 1) {
        /* new environment pointer goes just after this; everything that
           follows moves eight bytes less, which leaves that slot free */
        shift_offset = shift_offset - 8;
        env_entry_address = cur_addr - shift_offset;
      }
      if (n_zeros_found == 2) {
        in_aux = 1;
      }
    }
    cur_addr += 8;
    if (in_aux) { in_auxval = !in_auxval; }
  }
  /* the string goes right behind the relocated words */
  cur_addr = cur_addr - shift_offset;
  if (ptrace(PTRACE_POKEDATA, msg->pid, env_entry_address, cur_addr)) {
    inject_abort("PTRACE_POKEDATA envp", msg->pid);
  }
  union {
    char c[8];
    size_t i;
  } un;
  /* walk a cursor so the original pointer can still be freed */
  const char *src = to_inject;
  ssize_t len = strlen(to_inject) + 1;
  while (len > 0) {
    /* strncpy zero-pads the last, partial word */
    strncpy(un.c, src, 8);
    if (ptrace(PTRACE_POKEDATA, msg->pid, cur_addr, un.i)) {
      inject_abort("PTRACE_POKEDATA un.i", msg->pid);
    }
    len -= 8;
    cur_addr += 8;
    src += 8;
  }

  if (ptrace(PTRACE_SETREGS, msg->pid, NULL, &regs)) {
    inject_abort("PTRACE_SETREGS", msg->pid);
  }

  ptrace(PTRACE_DETACH, msg->pid, 0, NULL);
  free(to_inject);
}

/* Daemon entry point: open preload.so and the interceptor device, then
   loop forever reading one struct gei_user_msg per read() and either
   injecting into the process (when is_java() says so) or releasing it.
   Returns 1 when either open fails or the device can no longer be read. */
int main(int argc, char *argv[]) {
  (void)argc;
  (void)argv;
  int preload_fd = open("preload.so", O_RDONLY);
  if (preload_fd < 0) { perror("open(preload)"); return 1; }
  int f = open("/dev/global_execve_interceptor", O_RDONLY);
  if (f < 0) { perror("fopen"); return 1; }
  struct gei_user_msg msg;
  while (1) {
    ssize_t bytes = read(f, &msg, sizeof(msg));
    if (bytes < 0) {
      if (errno == EINTR) { continue; }
      perror("read(interceptor)");
      return 1;
    }
    if ((size_t)bytes != sizeof(msg)) {
      /* msg is incomplete (or untouched on EOF); acting on it would close
         and signal arbitrary descriptors */
      fprintf(stderr, "short read from interceptor: %zd of %zu bytes\n",
              bytes, sizeof(msg));
      return 1;
    }
    fprintf(stderr, "[%zd] %d %d %d\n", bytes, msg.gei_fd, msg.pid_fd, msg.pid);
    int procfd = get_procfd(&msg);
    if (procfd < 0) { continue; } /* get_procfd() already finished msg */
    if (is_java(procfd)) {
      inject_java(f, procfd, &msg, preload_fd);
    } else {
      finish_msg(&msg); /* inject_java() finishes as it needs to */
    }
    close(procfd);
  }
}
