tlsrp

A simple TLS reverse proxy
git clone git://nihaljere.xyz/tlsrp
Log | Files | Refs

commit ec4978ad9c269c194a1718985d9db285cc8eb275
Author: Nihal Jere <nihal@nihaljere.xyz>
Date:   Wed, 29 Apr 2020 22:12:28 -0500

initial commit, functional normal proxy

Diffstat:
AMakefile | 5+++++
Atlsrp.c | 267+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Autil.c | 164+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Autil.h | 32++++++++++++++++++++++++++++++++
4 files changed, 468 insertions(+), 0 deletions(-)

diff --git a/Makefile b/Makefile @@ -0,0 +1,5 @@ +all: + gcc tlsrp.c util.c -o tlsrp -lbsd + +clean: + rm tlsrp diff --git a/tlsrp.c b/tlsrp.c @@ -0,0 +1,267 @@ +#include <stdio.h> +#include <string.h> +#include <bsd/string.h> +#include <stdarg.h> +#include <stdlib.h> +#include <unistd.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/un.h> +#include <netdb.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <arpa/inet.h> +#include <poll.h> + +#include "util.h" + +// capped at 104 for portability +#define SUN_PATH_LENGTH 104 +#define BACKLOG 10 +#define BUF_SIZE 1024 +#define TIMEOUT 1000 +#define SERVER 0 +#define CLIENT 1 + +void +usage() +{ + puts("usage: tlsrp [-h host] -p port -f PORT"); + puts(" tlsrp -U unixsocket -f PORT"); + exit(1); +} + +// TODO add domain support? +static int +dobind(const char *host, const char *port) +{ + int sfd = -1; + struct addrinfo *results = NULL, *rp = NULL; + struct addrinfo hints = { .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM}; + + int err; + if ((err = getaddrinfo(host, port, &hints, &results)) != 0) + die("dobind: getaddrinfo: %s", gai_strerror(err)); + + for (rp = results; rp != NULL; rp = rp->ai_next) { + sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + + if (sfd == -1) + continue; + + if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) + break; + + close(sfd); + } + + if (rp == NULL) + die("failed to bind:"); + + free(results); + return sfd; +} + +static int +dounixconnect(const char *sockname) +{ + int sfd; + struct sockaddr_un saddr = {0}; + + if (strlen(sockname) > SUN_PATH_LENGTH-1) + die("unix socket path too long"); + + saddr.sun_family = AF_UNIX; + + strlcpy((char *) &saddr.sun_path, sockname, SUN_PATH_LENGTH); + + if ((sfd = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) + die("failed to create unix socket:"); + + if (connect(sfd, (struct sockaddr*)&saddr, sizeof(struct sockaddr_un)) == -1) { + close(sfd); + die("failed to connect to unix socket:"); + } + + return sfd; +} + +static int +donetworkconnect(const char* host, const char* port) +{ + int sfd = -1; + struct addrinfo *results = NULL, *rp = NULL; + struct addrinfo hints = { .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM}; + + if (getaddrinfo(host, port, &hints, &results) != 0) + die("getaddrinfo failed:"); + + for (rp = results; rp != NULL; rp = rp->ai_next) { + sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + + if (sfd == -1) + continue; + + if (connect(sfd, rp->ai_addr, rp->ai_addrlen) == 0) + break; + + close(sfd); + } + + if (rp == NULL) + warn("failed to connect:"); + + free(results); + return sfd; +} + +void dowrite(int fd, char* buf, size_t towrite) { + ssize_t written = 0; + while (towrite > 0) { + written = write(fd, buf, towrite); + if (written == -1) + die("failed to write:"); + towrite -= written; + buf += written; + } +} + +static int +serve(int serverfd, int clientfd) +{ + struct pollfd pfd[] = { + {serverfd, POLLIN | POLLOUT, 0}, + {clientfd, POLLIN | POLLOUT, 0} + }; + + char clibuf[BUF_SIZE] = {0}; + char serbuf[BUF_SIZE] = {0}; + + size_t clicount = 0, sercount = 0; + + while (poll(pfd, 2, TIMEOUT) != 0) { + if ((pfd[CLIENT].revents | pfd[SERVER].revents) & POLLNVAL) + return -1; + + if ((pfd[CLIENT].revents & POLLIN)) { + clicount = read(clientfd, clibuf, BUF_SIZE); + if (clicount == -1) { + die("client read failed:"); + return -2; + } + } + + if ((pfd[SERVER].revents & POLLIN)) { + sercount = read(serverfd, serbuf, BUF_SIZE); + if (sercount == -1) { + die("server read failed:"); + return -3; + } + } + + if ((pfd[SERVER].revents & POLLOUT) && clicount > 0) { + dowrite(serverfd, clibuf, clicount); + clicount = 0; + } + + if ((pfd[CLIENT].revents & POLLOUT) && sercount > 0) { + dowrite(clientfd, serbuf, sercount); + sercount = 0; + } + + if ((pfd[CLIENT].revents | pfd[SERVER].revents) & POLLHUP) + if (clicount == 0 && sercount == 0) + break; + + if ((pfd[CLIENT].revents | pfd[SERVER].revents) & POLLERR) + break; + } + return 0; +} + +int +main(int argc, char* argv[]) +{ + int to_server = 0, to_client = 0; + struct sockaddr_storage client_sa, server_sa = {0}; + socklen_t client_sa_len = 0; + int serverfd, bindfd; + char *usock = NULL, + *host = NULL, + *backport = NULL, + *frontport = NULL; + + if (argc < 3) + usage(); + + // TODO make parameter format enforcement stricter + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-U") == 0) + usock = argv[++i]; + else if (strcmp(argv[i], "-h") == 0) + host = argv[++i]; + else if (strcmp(argv[i], "-p") == 0) + backport = argv[++i]; + else if (strcmp(argv[i], "-f") == 0) + frontport = argv[++i]; + else + usage(); + } + + if (usock && (host || backport)) + die("cannot use both unix and network socket"); + + bindfd = dobind(NULL, frontport); + + if (listen(bindfd, BACKLOG) == -1) { + close(bindfd); + die("could not start listen:"); + } + + + pid_t pid; + + while (1) { + if ((to_client = accept(bindfd, (struct sockaddr*) &client_sa, + &client_sa_len)) == -1) { + warn("could not accept connection:"); + } + + switch ((pid = fork())) { + case -1: + warn("fork:"); + case 0: + if (usock) + to_server = dounixconnect(usock); + else + to_server = donetworkconnect(host, backport); + + if (to_server) + serve(to_server, to_client); + close(to_server); + close(to_client); + close(bindfd); + exit(0); + break; + default: + close(to_client); + } + } + + // TODO Initialize + // - validate addresses + // - create sockets + // - bind + // - listen + // TODO Serve + // - fork + // - accept connect + // - serve + // - close + // TODO Shutdown + // - close sockets + // - unlink +} + diff --git a/util.c b/util.c @@ -0,0 +1,164 @@ +/* See LICENSE file for copyright and license details. */ +#include <errno.h> +#include <limits.h> +#include <stdarg.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/types.h> +#include <time.h> + +#ifdef __OpenBSD__ +#include <unistd.h> +#endif /* __OpenBSD__ */ + +#include "util.h" + +char *argv0; + +static void +verr(const char *fmt, va_list ap) +{ + if (argv0 && strncmp(fmt, "usage", sizeof("usage") - 1)) { + fprintf(stderr, "%s: ", argv0); + } + + vfprintf(stderr, fmt, ap); + + if (fmt[0] && fmt[strlen(fmt) - 1] == ':') { + fputc(' ', stderr); + perror(NULL); + } else { + fputc('\n', stderr); + } +} + +void +warn(const char *fmt, ...) +{ + va_list ap; + + va_start(ap, fmt); + verr(fmt, ap); + va_end(ap); +} + +void +die(const char *fmt, ...) +{ + va_list ap; + + va_start(ap, fmt); + verr(fmt, ap); + va_end(ap); + + exit(1); +} + +void +epledge(const char *promises, const char *execpromises) +{ + (void)promises; + (void)execpromises; + +#ifdef __OpenBSD__ + if (pledge(promises, execpromises) == -1) { + die("pledge:"); + } +#endif /* __OpenBSD__ */ +} + +void +eunveil(const char *path, const char *permissions) +{ + (void)path; + (void)permissions; + +#ifdef __OpenBSD__ + if (unveil(path, permissions) == -1) { + die("unveil:"); + } +#endif /* __OpenBSD__ */ +} + +char * +timestamp(time_t t, char buf[TIMESTAMP_LEN]) +{ + strftime(buf, TIMESTAMP_LEN, "%a, %d %b %Y %T GMT", gmtime(&t)); + + return buf; +} + +int +esnprintf(char *str, size_t size, const char *fmt, ...) +{ + va_list ap; + int ret; + + va_start(ap, fmt); + ret = vsnprintf(str, size, fmt, ap); + va_end(ap); + + return (ret < 0 || (size_t)ret >= size); +} + +#define INVALID 1 +#define TOOSMALL 2 +#define TOOLARGE 3 + +long long +strtonum(const char *numstr, long long minval, long long maxval, + const char **errstrp) +{ + long long ll = 0; + int error = 0; + char *ep; + struct errval { + const char *errstr; + int err; + } ev[4] = { + { NULL, 0 }, + { "invalid", EINVAL }, + { "too small", ERANGE }, + { "too large", ERANGE }, + }; + + ev[0].err = errno; + errno = 0; + if (minval > maxval) { + error = INVALID; + } else { + ll = strtoll(numstr, &ep, 10); + if (numstr == ep || *ep != '\0') + error = INVALID; + else if ((ll == LLONG_MIN && errno == ERANGE) || ll < minval) + error = TOOSMALL; + else if ((ll == LLONG_MAX && errno == ERANGE) || ll > maxval) + error = TOOLARGE; + } + if (errstrp != NULL) + *errstrp = ev[error].errstr; + errno = ev[error].err; + if (error) + ll = 0; + + return ll; +} + +/* + * This is sqrt(SIZE_MAX+1), as s1*s2 <= SIZE_MAX + * if both s1 < MUL_NO_OVERFLOW and s2 < MUL_NO_OVERFLOW + */ +#define MUL_NO_OVERFLOW ((size_t)1 << (sizeof(size_t) * 4)) + +void * +reallocarray(void *optr, size_t nmemb, size_t size) +{ + if ((nmemb >= MUL_NO_OVERFLOW || size >= MUL_NO_OVERFLOW) && + nmemb > 0 && SIZE_MAX / nmemb < size) { + errno = ENOMEM; + return NULL; + } + return realloc(optr, size * nmemb); +} diff --git a/util.h b/util.h @@ -0,0 +1,32 @@ +/* See LICENSE file for copyright and license details. */ +#ifndef UTIL_H +#define UTIL_H + +#include <regex.h> +#include <stddef.h> +#include <time.h> + +#undef MIN +#define MIN(x,y) ((x) < (y) ? (x) : (y)) +#undef MAX +#define MAX(x,y) ((x) > (y) ? (x) : (y)) +#undef LEN +#define LEN(x) (sizeof (x) / sizeof *(x)) + +extern char *argv0; + +void warn(const char *, ...); +void die(const char *, ...); + +void epledge(const char *, const char *); +void eunveil(const char *, const char *); + +#define TIMESTAMP_LEN 30 + +char *timestamp(time_t, char buf[TIMESTAMP_LEN]); +int esnprintf(char *, size_t, const char *, ...); + +void *reallocarray(void *, size_t, size_t); +long long strtonum(const char *, long long, long long, const char **); + +#endif /* UTIL_H */