/*	$NetBSD$	*/

/*
 *	fopen_as_user() - open a file with the credentials of a user
 *
 * This function opens the file whose name is in the string pointed to by
 * "path", and associates a stream of "mode" with it.
 *
 * The call to open(2) is done with the credentials of "uid" and "gid".
 *
 * In this implementation the open(2) call is done in a child process after it
 * has called setuid(2) and setgid(2) and the resulting file descriptor is
 * passed securely back to the caller on a file descriptor opened through a
 * socketpair() in the AF_LOCAL domain.
 */

#include <sys/cdefs.h>
#if defined(LIBC_SCCS) && !defined(lint)
#if 0
static char sccsid[] = "@(#)fopen_as_user.c	1.0 (Planix) 2003";
#else
__RCSID("$NetBSD$");
#endif
#endif /* LIBC_SCCS and not lint */

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include "reentrant.h"
#include "local.h"

/* SCM_RIGHTS descriptor-passing helpers, defined below */
static ssize_t read_fd(int, void *, size_t, int *);
static ssize_t write_fd(int, void *, size_t, int);


/*
 * read_fd() - receive ordinary data plus one passed file descriptor
 *
 * Receives up to "nbytes" bytes of ordinary data into "ptr" from the
 * AF_LOCAL socket "fd", together with a single descriptor sent by the peer
 * as SCM_RIGHTS ancillary data.  The descriptor is stored in "*precvfd".
 *
 * Returns the number of data bytes received (> 0) on success.  On failure
 * returns -1, leaves "*precvfd" set to -1, and sets errno to:
 *	- the error reported by malloc(3) or recvmsg(2), or
 *	- EBADF if the peer closed the connection (EOF) or the message did
 *	  not carry exactly one SCM_RIGHTS descriptor.
 * recvmsg(2) calls interrupted by a signal (EINTR) are retried.
 *
 * More or less copied from W. Richard Stevens' "UNIX Network Programming",
 * Vol. 1 (2nd ed), p. 387 (obsolete code for old recvmsg() deleted, and
 * converted to use malloc() because CMSG_SPACE() is not a compile-time
 * constant on NetBSD)
 */
static ssize_t
read_fd(int fd, void *ptr, size_t nbytes, int *precvfd)
{
	struct msghdr	msg;
	struct iovec	iov;
	struct cmsghdr *cmptr;
	char           *control;
	size_t          controlspace;
	ssize_t		n;
	int             oerrno;

	*precvfd = -1;			/* assume the worst.... */

	controlspace = CMSG_SPACE(sizeof(*precvfd));
	if ((control = malloc(controlspace)) == NULL)
		return -1;

	/* memset() also clears msg_name, msg_namelen and msg_flags */
	memset(&msg, 0, sizeof(msg));
	/* the receiver offers the full padded space for one descriptor */
	msg.msg_control = control;
	msg.msg_controllen = (socklen_t) controlspace;

	iov.iov_base = ptr;		/* ordinary (non-ancillary) data */
	iov.iov_len = nbytes;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	do {
		n = recvmsg(fd, &msg, 0);
	} while (n < 0 && errno == EINTR);

	if (n < 0) {
		oerrno = errno;
		goto fail;
	}
	if (n == 0) {
		/*
		 * EOF: the peer went away without sending anything, so no
		 * descriptor can have been passed.  Report that explicitly
		 * rather than leaving whatever stale value errno held.
		 */
		oerrno = EBADF;
		goto fail;
	}

	cmptr = CMSG_FIRSTHDR(&msg);
	if (cmptr == NULL ||
	    cmptr->cmsg_len != CMSG_LEN(sizeof(*precvfd)) ||
	    cmptr->cmsg_level != SOL_SOCKET ||
	    cmptr->cmsg_type != SCM_RIGHTS) {
		oerrno = EBADF;		/* descriptor was not passed */
		goto fail;
	}
	/* memcpy() avoids any alignment/aliasing assumptions on CMSG_DATA() */
	memcpy(precvfd, CMSG_DATA(cmptr), sizeof(*precvfd));

	free(control);
	return n;

fail:
	free(control);
	errno = oerrno;
	return -1;
}

/*
 * write_fd() - send ordinary data plus one file descriptor
 *
 * Sends "nbytes" bytes of ordinary data from "ptr" over the AF_LOCAL socket
 * "fd", together with the descriptor "sendfd" as SCM_RIGHTS ancillary data.
 * Callers should send at least one byte of ordinary data so the receiver
 * can tell a descriptor-carrying message apart from EOF.
 *
 * Returns the value of sendmsg(2): the number of data bytes sent, or -1
 * with errno set.  Also returns -1 (errno from calloc(3)) if the control
 * buffer cannot be allocated.  sendmsg(2) calls interrupted by a signal
 * (EINTR) are retried.
 *
 * More or less copied from W. Richard Stevens' "UNIX Network Programming",
 * Vol. 1 (2nd ed), p. 389 (obsolete code for old sendmsg() deleted, and
 * converted to use dynamic allocation because CMSG_SPACE() is not a
 * compile-time constant on NetBSD)
 */
static ssize_t
write_fd(int fd, void *ptr, size_t nbytes, int sendfd)
{
	struct msghdr	msg;
	struct iovec	iov;
	struct cmsghdr *cmptr;
	char           *control;
	ssize_t         rv;
	int             oerrno;

	/*
	 * calloc() rather than malloc(): CMSG_DATA() may sit after alignment
	 * padding following the cmsghdr, and that padding is part of the
	 * CMSG_LEN() bytes handed to the kernel, so it must be initialized.
	 */
	if ((control = calloc(1, CMSG_SPACE(sizeof(sendfd)))) == NULL)
		return -1;

	/* memset() also clears msg_name, msg_namelen and msg_flags */
	memset(&msg, 0, sizeof(msg));
	msg.msg_control = control;
	msg.msg_controllen = (socklen_t) CMSG_LEN(sizeof(sendfd));

	cmptr = CMSG_FIRSTHDR(&msg);
	cmptr->cmsg_len = (socklen_t) CMSG_LEN(sizeof(sendfd));
	cmptr->cmsg_level = SOL_SOCKET;
	cmptr->cmsg_type = SCM_RIGHTS;
	/* memcpy() avoids any alignment/aliasing assumptions on CMSG_DATA() */
	memcpy(CMSG_DATA(cmptr), &sendfd, sizeof(sendfd));

	iov.iov_base = ptr;		/* ordinary (non-ancillary) data */
	iov.iov_len = nbytes;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	do {
		rv = sendmsg(fd, &msg, 0);
	} while (rv < 0 && errno == EINTR);
	oerrno = errno;

	free(control);
	errno = oerrno;

	return rv;
}

/*
 *	fopen_as_user() - open a file using uid:gid privileges only.
 *
 * Opens "path" using the credentials "uid" and "gid" and associates a stdio
 * stream of "mode" with it.  Returns the stream, or NULL with errno set.
 *
 * This is the magic, forking, file-descriptor passing, version!  A child
 * process calls setgid(2) and then setuid(2), then open(2).  It passes the
 * resulting descriptor back to the parent over an AF_LOCAL socketpair(2).
 * The child reports a failure by exiting with status (errno + 1); the parent
 * turns that back into errno.  If the child is terminated by a signal, errno
 * is set to EINTR.
 *
 * The basic concept for descriptor passing has been more or less copied from
 * W. Richard Stevens' "UNIX Network Programming", Vol. 1 (2nd ed), p. 385
 *
 * XXX WARNING: This is an incomplete implementation! (e.g. no initgroups():
 * only setgid()/setuid() are called, so the child keeps the caller's
 * supplementary group list while opening the file.)
 *
 * NOTE(review): the child calls write_fd(), which allocates memory.  After
 * fork() in a multithreaded caller that is not async-signal-safe.
 */
FILE *
fopen_as_user(const char *path, const char *mode, uint32_t uid, uint32_t gid)
{
	FILE *fp;
	pid_t pid;
	int fd = -1;
	int oflags;
	int sockfd[2];
	int status;
	int oerrno;
	char ch;

#if defined(ELAST) && (ELAST > 254)
# error "This code returns errno via _exit() and ELAST > 254!"
#endif

	/*
	 * Validate "mode" before creating any resources: there is no point
	 * forking a child just to have it reject a bad mode string.
	 * __sflags() sets errno itself on failure.
	 */
	if (__sflags(mode, &oflags) == 0)	/* magic *BSD stdio internals */
		return NULL;

	if (socketpair(AF_LOCAL, SOCK_STREAM, 0, sockfd) < 0)
		return NULL;

	switch ((pid = fork())) {
	case -1:
		/* fork() error */
		oerrno = errno;
		close(sockfd[0]);
		close(sockfd[1]);
		errno = oerrno;
		return NULL;
		/* NOTREACHED */

	case 0: {
		/* in the child process */
		int cfd;
		char emptystr[1] = "";

		/*
		 * Take care to use _exit() so as to avoid any side-effects of
		 * exit(), such as atexit()s registered by the parent process.
		 *
		 * Note this means we have to be very careful to always close
		 * what we open!
		 *
		 * setgid() must come before setuid(), while we still have the
		 * privilege to change the group.
		 */
		close(sockfd[0]);

		if (setgid(gid) != 0 || setuid(uid) != 0 ||
		    (cfd = open(path, oflags, DEFFILEMODE)) < 0) {
			oerrno = errno;
			close(sockfd[1]);
			_exit(oerrno + 1);
		}
		/*
		 * When sending a descriptor across a stream pipe we always
		 * send at least one byte of data, even if the receiver does
		 * nothing with the data, otherwise the receiver cannot tell
		 * whether a return value of 0 from read_fd() means "no data
		 * (but possibly a descriptor)", or just "end of file".
		 */
		if (write_fd(sockfd[1], emptystr, 1, cfd) < 0) {
			oerrno = errno;
			close(sockfd[1]);
			close(cfd);
			_exit(oerrno + 1);
		}
		close(cfd);
		close(sockfd[1]);
		_exit(0);
		/* NOTREACHED */
	}

	default:
		/* in the parent process */
		break;
	}

	close(sockfd[1]);

	/*
	 * Reap the child.  Retry if a signal interrupts the wait, otherwise
	 * we would fail spuriously and leave a zombie behind.  The child
	 * only writes one byte plus a descriptor, which fits in the socket
	 * buffer, so it can exit before we read.
	 */
	while (waitpid(pid, &status, 0) < 0) {
		if (errno != EINTR) {
			oerrno = errno;
			close(sockfd[0]);
			errno = oerrno;
			return NULL;
		}
	}
	if (!WIFEXITED(status)) {
		/* child did not exit normally (e.g. killed by a signal) */
		close(sockfd[0]);
		errno = EINTR;
		return NULL;
	}
	if (WEXITSTATUS(status) != 0) {
		/* the child encoded its errno as (errno + 1) */
		close(sockfd[0]);
		errno = WEXITSTATUS(status) - 1;
		return NULL;
	}
	if (read_fd(sockfd[0], &ch, (size_t) 1, &fd) <= 0 || fd < 0) {
		oerrno = errno;
		close(sockfd[0]);
		errno = oerrno;
		return NULL;
	}
	close(sockfd[0]);

	if ((fp = fdopen(fd, mode)) == NULL) {
		oerrno = errno;
		close(fd);
		errno = oerrno;
		return NULL;
	}

	return fp;
}
