tlsgate

TLS reverse proxy
git clone git://git.akobets.xyz/tlsgate
Log | Files | Refs | README | LICENSE

commit 7404997ab94c4695508cf0d8e75de233f2321092
parent 3b76ce165811b3d292e481c3316b5225b8e6432e
Author: Artem Kobets <artem@akobets.xyz>
Date:   Sat,  5 Sep 2020 10:25:17 +0300

rewrite with poll() instead of child/parent

Diffstat:
MMakefile | 5++---
Mmain.c | 223+++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------
Dserve.c | 140-------------------------------------------------------------------------------
Dserve.h | 4----
Msock.c | 22+++-------------------
Msock.h | 5++---
Mtlsgate.1 | 10+++++-----
7 files changed, 163 insertions(+), 246 deletions(-)

diff --git a/Makefile b/Makefile @@ -7,7 +7,6 @@ MANDIR = $(PREFIX)/share/man LIBS = -ltls ALL_CFLAGS = -std=c99 -pedantic -Wall \ - -Wmissing-prototypes -Wstrict-prototypes \ -D_POSIX_C_SOURCE=200809L \ -DVERSION=\"$(VERSION)\" \ $(CFLAGS) $(CPPFLAGS) @@ -15,9 +14,9 @@ ALL_LDFLAGS = $(LIBS) $(LDFLAGS) CC = cc -SRC = main.c serve.c sock.c util.c +SRC = main.c sock.c util.c OBJ = $(SRC:.c=.o) -HDR = serve.h sock.h util.h +HDR = sock.h util.h all: tlsgate diff --git a/main.c b/main.c @@ -1,4 +1,5 @@ #include <errno.h> +#include <poll.h> #include <signal.h> #include <stdio.h> #include <stdlib.h> @@ -7,18 +8,142 @@ #include <unistd.h> #include <sys/socket.h> +#include <tls.h> -#include "serve.h" #include "sock.h" #include "util.h" -#define SOCK_TIMEOUT_SECS 30 +// milliseconds +#define REQUEST_TIMEOUT 30000 +static void serve(struct tls *ctx, int cfd, const char *server_host, const char *server_port, const char *server_udsfile); static void sigchld(int unused); static void usage(void); char *argv0; +void +serve(struct tls *ctx, int cfd, const char *server_host, const char *server_port, const char *server_udsfile) +{ + struct tls *cctx = NULL; + int sfd = -1; + int ret; + struct pollfd pfds[2]; + char buf[BUFSIZ], *bufp; + ssize_t nread, nwritten; + + if (tls_accept_socket(ctx, &cctx, cfd) == -1) { + warn("tls_accept_socket"); + goto cleanup; + } + + // connect to server + sfd = server_udsfile + ? sock_server_uds(server_udsfile) + : sock_server_ips(server_host, server_port); + if (sfd == -1) + goto cleanup; + + // client + pfds[0].fd = cfd; + pfds[0].events = POLLIN; + // server + pfds[1].fd = sfd; + pfds[1].events = POLLIN; + + while ((ret = poll(pfds, 2, REQUEST_TIMEOUT)) > 0) { + if (pfds[0].revents & POLLERR) { + warnx("client fd error"); + goto cleanup; + } else if (pfds[0].revents & POLLIN) { + // read client->proxy + while (1) { + nread = tls_read(cctx, buf, sizeof(buf)); + if ( + nread == TLS_WANT_POLLIN || + nread == TLS_WANT_POLLOUT + ) { + continue; + } else { + break; + } + } + + if (nread == -1) { + warnx("client fd read error %s", tls_error(cctx)); + goto cleanup; + } else if (nread == 0) { + goto cleanup; + } else { + // write proxy->server + bufp = buf; + while (nread > 0) { + nwritten = write(sfd, bufp, nread); + if (nwritten == -1) { + warnx("server fd write error"); + goto cleanup; + } else { + nread -= nwritten; + bufp += nwritten; + } + } + } + } + + if (pfds[1].revents & POLLERR) { + warnx("server fd error"); + goto cleanup; + } else if (pfds[1].revents & POLLIN) { + // read server->proxy + nread = read(sfd, buf, sizeof(buf)); + + if (nread == -1) { + warnx("server fd read error"); + goto cleanup; + } else if (nread == 0) { + goto cleanup; + } else { + // write proxy->client + bufp = buf; + while (nread > 0) { + nwritten = tls_write(cctx, bufp, nread); + if (nwritten == -1) { + warnx("client fd write error"); + goto cleanup; + } else if ( + nwritten == TLS_WANT_POLLIN || + nwritten == TLS_WANT_POLLOUT + ) { + continue; + } else { + nread -= nwritten; + bufp += nwritten; + } + } + } + } + } + +cleanup: + if (cctx != NULL) { + while (1) { + ret = tls_close(cctx); + if ( + ret == TLS_WANT_POLLIN || + ret == TLS_WANT_POLLOUT + ) { + continue; + } else { + break; + } + } + tls_free(cctx); + } + close(cfd); + if (sfd != -1) + close(sfd); +} + static void sigchld(int unused) { @@ -36,6 +161,7 @@ usage(void) ); } + int main(int argc, char **argv) { @@ -43,11 +169,11 @@ main(int argc, char **argv) char *cert_file = NULL; char *key_file = NULL; char *ca_file = NULL; + char *proxy_host = NULL; + char *proxy_port = NULL; char *server_host = NULL; char *server_port = NULL; - char *client_host = NULL; - char *client_port = NULL; - char *client_udsfile = NULL; + char *server_udsfile = NULL; int maxnprocs = 512; struct rlimit rlim; @@ -71,19 +197,19 @@ main(int argc, char **argv) ca_file = optarg; break; case 'h': - server_host = optarg; + proxy_host = optarg; break; case 'p': - server_port = optarg; + proxy_port = optarg; break; case 'H': - client_host = optarg; + server_host = optarg; break; case 'P': - client_port = optarg; + server_port = optarg; break; case 'U': - client_udsfile = optarg; + server_udsfile = optarg; break; case 'n': maxnprocs = atol(optarg); @@ -97,25 +223,25 @@ main(int argc, char **argv) } } - /* cert and private key files are required */ + // cert and private key files are required if (cert_file == NULL || key_file == NULL) usage(); - /* server host can be NULL, port is required */ - if (server_port == NULL) + // proxy host can be NULL, port is required + if (proxy_port == NULL) usage(); - /* allow IPS or UDS client */ + // allow IPS or UDS server if ( - (client_host != NULL && client_udsfile != NULL) || - !(client_port != NULL || client_udsfile != NULL) + (server_host != NULL && server_udsfile != NULL) || + !(server_port != NULL || server_udsfile != NULL) ) usage(); - /* process limit */ + // process limit rlim.rlim_cur = rlim.rlim_max = maxnprocs; if (setrlimit(RLIMIT_NPROC, &rlim) == -1) err("setrlimit RLIMIT_NPROC"); - /* setup tls */ + // setup tls if ((ctx = tls_server()) == NULL) err("tls_server"); if ((config = tls_config_new()) == NULL) @@ -131,10 +257,10 @@ main(int argc, char **argv) if (tls_configure(ctx, config) == -1) err("tls_configure"); - /* setup server socket */ - fd = sock_server_ips(server_host, server_port); + // setup proxy socket + fd = sock_proxy_ips(proxy_host, proxy_port); - /* reap children */ + // reap children act.sa_handler = sigchld; sigemptyset(&sigmask); act.sa_mask = sigmask; @@ -146,7 +272,7 @@ main(int argc, char **argv) int cfd = -1; if ((cfd = accept(fd, NULL, NULL)) == -1) { - /* can be interrupted with SIGCHLD */ + // can be interrupted with SIGCHLD if (errno != EINTR) warn("accept"); continue; @@ -154,62 +280,15 @@ main(int argc, char **argv) switch (pid = fork()) { case 0: { - struct tls *cctx = NULL; - int clientfd = -1; - int ret; - close(fd); - - if (sock_set_timeout(cfd, SOCK_TIMEOUT_SECS) == -1) - goto cleanup; - - /* start tls */ - if (tls_accept_socket(ctx, &cctx, cfd) == -1) { - warn("tls_accept_socket"); - goto cleanup; - } - while (1) { - ret = tls_handshake(cctx); - if (ret == -1) { - goto cleanup; - } else if ( - ret == TLS_WANT_POLLIN || - ret == TLS_WANT_POLLOUT - ) { - continue; - } else { - break; - } - } - - /* connect to client */ - clientfd = client_udsfile - ? sock_client_uds(client_udsfile) - : sock_client_ips(client_host, client_port); - if (clientfd == -1) - goto cleanup; - if (sock_set_timeout(clientfd, SOCK_TIMEOUT_SECS) == -1) - goto cleanup; - - serve(cctx, clientfd); - -cleanup: - if (cctx != NULL) { - full_tls_close(cctx); - tls_free(cctx); - } - shutdown(cfd, SHUT_RDWR); - close(cfd); - if (clientfd != -1) - close(clientfd); + serve(ctx, cfd, server_host, server_port, server_udsfile); _exit(0); - break; } case -1: warn("fork"); - /* fallthrough */ + // fallthrough default: - /* close connection in parent */ + // close connection in parent close(cfd); break; } diff --git a/serve.c b/serve.c @@ -1,140 +0,0 @@ -#include <errno.h> -#include <stdio.h> -#include <stdlib.h> -#include <sys/wait.h> -#include <unistd.h> - -#include <sys/socket.h> - -#include "serve.h" -#include "util.h" - -#define BUFMAX 4096 - -static int full_write(int fd, char *buf, ssize_t len); -static int full_tls_write(struct tls *ctx, char *buf, ssize_t len); - -void -serve(struct tls *ctx, int fd) -{ - pid_t pid; - - pid = fork(); - if (pid == -1) { - warn("fork"); - full_tls_close(ctx); - shutdown(fd, SHUT_RDWR); - return; - } - /* reverse proxy - parent and child connect - * encrypted socket and unencrypted socket - * until one of the connections is terminated */ - if (pid == 0) { - while (1) { - char buf[BUFMAX]; - ssize_t nread; - - nread = read(fd, buf, sizeof(buf)); - if (nread == -1 && errno != EINTR) - break; - if (nread == 0) - break; - - if (full_tls_write(ctx, buf, nread) == -1) - break; - } - - full_tls_close(ctx); - shutdown(fd, SHUT_RDWR); - _exit(0); - } else { - while (1) { - char buf[BUFMAX]; - ssize_t nread; - - while (1) { - nread = tls_read(ctx, buf, sizeof(buf)); - if ( - nread == TLS_WANT_POLLIN || - nread == TLS_WANT_POLLOUT - ) { - continue; - } else { - break; - } - } - if (nread == -1) - break; - if (nread == 0) - break; - - if (full_write(fd, buf, nread) == -1) - break; - } - - full_tls_close(ctx); - shutdown(fd, SHUT_RDWR); - } -} - -int -full_tls_close(struct tls *ctx) -{ - int ret; - - while (1) { - ret = tls_close(ctx); - if ( - ret == TLS_WANT_POLLIN || - ret == TLS_WANT_POLLOUT - ) { - continue; - } else { - return ret; - } - } -} - -static int -full_write(int fd, char *buf, ssize_t len) -{ - int nwritten; - - while (len > 0) { - nwritten = write(fd, buf, len); - if (nwritten == -1) { - if (errno == EINTR) - continue; - else - return -1; - } else { - len -= nwritten; - buf += nwritten; - } - } - - return 0; -} - -static int -full_tls_write(struct tls *ctx, char *buf, ssize_t len) -{ - int nwritten; - - while (len > 0) { - nwritten = tls_write(ctx, buf, len); - if (nwritten == -1) { - return -1; - } else if ( - nwritten == TLS_WANT_POLLIN || - nwritten == TLS_WANT_POLLOUT - ) { - continue; - } else { - len -= nwritten; - buf += nwritten; - } - } - - return 0; -} diff --git a/serve.h b/serve.h @@ -1,4 +0,0 @@ -#include <tls.h> - -void serve(struct tls *ctx, int clientfd); -int full_tls_close(struct tls *ctx); diff --git a/sock.c b/sock.c @@ -12,7 +12,7 @@ #include "util.h" int -sock_server_ips(const char *host, const char *port) +sock_proxy_ips(const char *host, const char *port) { struct addrinfo hints = { .ai_flags = AI_NUMERICSERV, @@ -54,7 +54,7 @@ sock_server_ips(const char *host, const char *port) } int -sock_client_ips(const char *host, const char *port) +sock_server_ips(const char *host, const char *port) { struct addrinfo hints = { .ai_flags = AI_NUMERICSERV, @@ -94,7 +94,7 @@ sock_client_ips(const char *host, const char *port) } int -sock_client_uds(const char *file) +sock_server_uds(const char *file) { int fd; struct sockaddr_un addr; @@ -114,19 +114,3 @@ sock_client_uds(const char *file) return fd; } - -int -sock_set_timeout(int fd, int sec) -{ - struct timeval time; - - time.tv_sec = sec; - time.tv_usec = 0; - if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &time, sizeof(time)) == -1 || - setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &time, sizeof(time)) == -1) { - warn("failed to set socket timeout"); - return -1; - } - - return 0; -} diff --git a/sock.h b/sock.h @@ -1,4 +1,3 @@ +int sock_proxy_ips(const char *host, const char *port); int sock_server_ips(const char *host, const char *port); -int sock_client_ips(const char *host, const char *port); -int sock_client_uds(const char *file); -int sock_set_timeout(int fd, int sec); +int sock_server_uds(const char *file); diff --git a/tlsgate.1 b/tlsgate.1 @@ -25,19 +25,19 @@ Path to private key. Path to CA root certificates. .TP .B \-h host -TLS server's hostname. +TLS proxy hostname. .TP .B \-p port -TLS server's port number. +TLS proxy port number. .TP .B \-H host -Client's hostname. +Server hostname. .TP .B \-P port -Client's port number. +Server port number. .TP .B \-U file -Client's UNIX domain socket path. +Server UNIX domain socket path. .TP .B \-n proc-num Maximum number of threads. Default is 512.