tlsgate

TLS proxy
git clone git://git.akobets.xyz/tlsgate
Log | Files | Refs | README | LICENSE

main.c (10406B)


#include <errno.h>
#include <limits.h>
#include <poll.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <netinet/in.h>
#include <tls.h>

#include "sock.h"
#include "util.h"
     15 
     16 enum {
     17   MODE_NONE,
     18   MODE_TLS_CLIENT,
     19   MODE_TLS_SERVER
     20 };
     21 
     22 struct settings {
     23   int mode;
     24   int timeout;
     25   struct tls *tls_ctx;
     26   char *server_host;
     27   char *server_port;
     28   char *server_udsfile;
     29 };
     30 
     31 static void serve(const struct settings *s, int cfd, const struct sockaddr_storage *addr);
     32 static void cleanup(void);
     33 static void sigcleanup(int sig);
     34 static void handle_termsignals(void (*func)(int));
     35 static void usage(void);
     36 
     37 char *argv0;
     38 static char *proxy_udsfile;
     39 
     40 void
     41 serve(const struct settings *s, int cfd, const struct sockaddr_storage *addr)
     42 {
     43   time_t t;
     44   char addr_str[INET6_ADDRSTRLEN]; // > INET_ADDRSTRLEN
     45   char tstmp[21];
     46 
     47   int sfd = -1, fd_tls, fd_norm;
     48   struct tls *tls_ctx = NULL;
     49   int nready;
     50   int ret;
     51   struct pollfd pfds[2], pfd[1];
     52   char buf[BUFSIZ], *bufp;
     53   ssize_t nread, nwritten;
     54   int poll_timeout;
     55 
     56   // log
     57   t = time(NULL);
     58   if (strftime(tstmp, sizeof(tstmp), "%Y-%m-%dT%H:%M:%SZ", gmtime(&t)) == 0) {
     59     warnx("strftime: Exceeded buffer capacity");
     60     goto cleanup;
     61   }
     62   if (sock_get_addr_str(addr, addr_str, sizeof(addr_str)) == -1)
     63     goto cleanup;
     64   printf("%s %s\n", tstmp, addr_str);
     65 
     66   // connect to server
     67   sfd = s->server_udsfile
     68     ? sock_server_uds(s->server_udsfile)
     69     : sock_server_ips(s->server_host, s->server_port);
     70   if (sfd == -1)
     71     goto cleanup;
     72 
     73   if (s->mode == MODE_TLS_CLIENT) {
     74     fd_tls = cfd;
     75     fd_norm = sfd;
     76   } else {
     77     // MODE_TLS_SERVER
     78     fd_tls = sfd;
     79     fd_norm = cfd;
     80   }
     81 
     82   if (sock_set_nonblock(fd_tls) == -1)
     83     goto cleanup;
     84   if (sock_set_nonblock(fd_norm) == -1)
     85     goto cleanup;
     86 
     87   if (s->mode == MODE_TLS_CLIENT) {
     88     if (tls_accept_socket(s->tls_ctx, &tls_ctx, fd_tls) == -1) {
     89       warn("tls_accept_socket");
     90       goto cleanup;
     91     }
     92   } else {
     93     // MODE_TLS_SERVER
     94     if (tls_connect_socket(s->tls_ctx, fd_tls, s->server_host) == -1) {
     95       warn("tls_connect_socket");
     96       goto cleanup;
     97     }
     98     tls_ctx = s->tls_ctx;
     99   }
    100 
    101   pfds[0].fd = fd_tls;
    102   pfds[0].events = POLLIN | POLLOUT;
    103   pfds[1].fd = fd_norm;
    104   pfds[1].events = POLLIN;
    105 
    106   poll_timeout = s->timeout > 0 ? s->timeout * 1000 : -1;
    107   // when checking POLLIN, also check POLLHUP
    108   while (poll(pfds, 2, poll_timeout) > 0) {
    109     // tls->proxy->normal
    110     if (pfds[0].revents & (pfds[0].events | POLLHUP)) {
    111       // read tls->proxy
    112       while (1) {
    113         nread = tls_read(tls_ctx, buf, sizeof(buf));
    114         if (nread == TLS_WANT_POLLIN) {
    115           pfds[0].events = POLLIN;
    116           break;
    117         } else if (nread == TLS_WANT_POLLOUT) {
    118           pfds[0].events = POLLOUT;
    119           break;
    120         } else if (nread == -1) {
    121           // Blocking error codes (EAGAIN, EWOULDBLOCK)
    122           // are handled by TLS_WANT_POLLIN and TLS_WANT_POLLOUT.
    123           // Do not check for them here.
    124           goto cleanup;
    125         } else if (nread == 0) {
    126           goto cleanup;
    127         } else {
    128           // write proxy->normal
    129           pfd[0].fd = fd_norm;
    130           pfd[0].events = POLLOUT;
    131           bufp = buf;
    132           while (nread > 0) {
    133             nready = poll(pfd, 1, poll_timeout);
    134             if (nready == -1) {
    135               goto cleanup;
    136             } else if (nready == 0) {
    137               goto cleanup;
    138             } else {
    139               if (pfd[0].revents & pfd[0].events) {
    140                 nwritten = write(fd_norm, bufp, nread);
    141                 if (nwritten == -1) {
    142                   goto cleanup;
    143                 } else {
    144                   nread -= nwritten;
    145                   bufp += nwritten;
    146                 }
    147               } else if (pfd[0].revents != 0) {
    148                 // POLLERR, POLLHUP
    149                 goto cleanup;
    150               }
    151             }
    152           }
    153         }
    154       }
    155     } else if (pfds[0].revents != 0) {
    156       // POLLERR
    157       goto cleanup;
    158     }
    159 
    160     // normal->proxy->tls
    161     if (pfds[1].revents & (pfds[1].events | POLLHUP)) {
    162       // read normal->proxy
    163       while (1) {
    164         nread = read(fd_norm, buf, sizeof(buf));
    165         if (nread == -1) {
    166           if (errno == EAGAIN || errno == EWOULDBLOCK) {
    167             break;
    168           } else {
    169             goto cleanup;
    170           }
    171         } else if (nread == 0) {
    172           goto cleanup;
    173         } else {
    174           // write proxy->tls
    175           pfd[0].fd = fd_tls;
    176           pfd[0].events = POLLIN | POLLOUT;
    177           bufp = buf;
    178           while (nread > 0) {
    179             nready = poll(pfd, 1, poll_timeout);
    180             if (nready == -1) {
    181               goto cleanup;
    182             } else if (nready == 0) {
    183               goto cleanup;
    184             } else {
    185               if (pfd[0].revents & (pfd[0].events | POLLHUP)) {
    186                 nwritten = tls_write(tls_ctx, bufp, nread);
    187                 if (nwritten == TLS_WANT_POLLIN) {
    188                   pfd[0].events = POLLIN;
    189                 } else if (nwritten == TLS_WANT_POLLOUT) {
    190                   pfd[0].events = POLLOUT;
    191                 } else if (nwritten == -1) {
    192                   goto cleanup;
    193                 } else {
    194                   nread -= nwritten;
    195                   bufp += nwritten;
    196                 }
    197               } else if (pfd[0].revents != 0) {
    198                 // POLLERR
    199                 goto cleanup;
    200               }
    201             }
    202           }
    203         }
    204       }
    205     } else if (pfds[1].revents != 0) {
    206       // POLLERR
    207       goto cleanup;
    208     }
    209   }
    210 
    211 cleanup:
    212   if (tls_ctx != NULL) {
    213     while (1) {
    214       ret = tls_close(tls_ctx);
    215       if (
    216         ret == TLS_WANT_POLLIN ||
    217         ret == TLS_WANT_POLLOUT
    218       ) {
    219         continue;
    220       } else {
    221         break;
    222       }
    223     }
    224     tls_free(tls_ctx);
    225   }
    226   close(cfd);
    227   if (sfd != -1)
    228     close(sfd);
    229 }
    230 
    231 static void
    232 cleanup(void)
    233 {
    234   if (proxy_udsfile != NULL)
    235     sock_remove_uds(proxy_udsfile);
    236 }
    237 
    238 static void
    239 sigcleanup(int sig)
    240 {
    241   struct sigaction act;
    242 
    243   cleanup();
    244 
    245   act.sa_handler = SIG_DFL;
    246   sigemptyset(&act.sa_mask);
    247   act.sa_flags = 0;
    248   sigaction(sig, &act, NULL);
    249 
    250   raise(sig);
    251 }
    252 
    253 static void
    254 handle_termsignals(void (*func)(int))
    255 {
    256   struct sigaction act;
    257 
    258   act.sa_handler = func;
    259   sigemptyset(&act.sa_mask);
    260   act.sa_flags = 0;
    261   sigaction(SIGTERM, &act, NULL);
    262   sigaction(SIGINT, &act, NULL);
    263 }
    264 
    265 static void
    266 usage(void)
    267 {
    268   fprintf(
    269     stderr,
    270     "usage: %s -s/-S [-h host] [-p port] [-u file]"
    271     " [-H host] [-P port] [-U file]"
    272     " [-c cert] [-k key] [-C ca] [-t timeout] [-v]\n",
    273     argv0
    274   );
    275   exit(1);
    276 }
    277 
    278 int
    279 main(int argc, char **argv)
    280 {
    281   struct settings s = {
    282     .mode = MODE_NONE,
    283     .tls_ctx = NULL,
    284     .timeout = 30,
    285     .server_host = NULL,
    286     .server_port = NULL,
    287     .server_udsfile = NULL
    288   };
    289   char *tls_cert_file = NULL;
    290   char *tls_key_file = NULL;
    291   char *tls_ca_file = NULL;
    292   char *proxy_host = NULL;
    293   char *proxy_port = NULL;
    294   int opt;
    295   struct tls_config *config;
    296   int fd, cfd;
    297   struct sigaction act;
    298   struct sockaddr_storage addr;
    299   socklen_t addr_len;
    300 
    301   argv0 = argv[0];
    302 
    303   while ((opt = getopt(argc, argv, "sSh:p:u:H:P:U:c:k:C:t:v")) != -1) {
    304     switch (opt) {
    305     case 's':
    306       s.mode = MODE_TLS_CLIENT;
    307       break;
    308     case 'S':
    309       s.mode = MODE_TLS_SERVER;
    310       break;
    311     case 'c':
    312       tls_cert_file = optarg;
    313       break;
    314     case 'k':
    315       tls_key_file = optarg;
    316       break;
    317     case 'C':
    318       tls_ca_file = optarg;
    319       break;
    320     case 'h':
    321       proxy_host = optarg;
    322       break;
    323     case 'p':
    324       proxy_port = optarg;
    325       break;
    326     case 'u':
    327       proxy_udsfile = optarg;
    328       break;
    329     case 'H':
    330       s.server_host = optarg;
    331       break;
    332     case 'P':
    333       s.server_port = optarg;
    334       break;
    335     case 'U':
    336       s.server_udsfile = optarg;
    337       break;
    338     case 't':
    339       s.timeout = atoi(optarg);
    340       break;
    341     case 'v':
    342       puts(VERSION);
    343       exit(0);
    344       break;
    345     default:
    346       usage();
    347       break;
    348     }
    349   }
    350 
    351   if (s.mode == MODE_NONE)
    352     usage();
    353   // If accepting TLS connections, cert and private key files are required
    354   if (s.mode == MODE_TLS_CLIENT) {
    355     if (tls_cert_file == NULL || tls_key_file == NULL)
    356       usage();
    357   }
    358   // allow IPS or UDS proxy
    359   // UDS not allowed with TLS
    360   if (
    361     !((proxy_host != NULL && proxy_port != NULL && proxy_udsfile == NULL) ||
    362     (proxy_host == NULL && proxy_port == NULL && proxy_udsfile != NULL)) ||
    363     (s.mode == MODE_TLS_CLIENT && proxy_udsfile != NULL)
    364   )
    365     usage();
    366   // allow IPS or UDS server
    367   // UDS not allowed with TLS
    368   if (
    369     !((s.server_host != NULL && s.server_port != NULL && s.server_udsfile == NULL) ||
    370     (s.server_host == NULL && s.server_port == NULL && s.server_udsfile != NULL)) ||
    371     (s.mode == MODE_TLS_SERVER && s.server_udsfile != NULL)
    372   )
    373     usage();
    374 
    375   // setup tls
    376   if (s.mode == MODE_TLS_CLIENT) {
    377     if ((s.tls_ctx = tls_server()) == NULL)
    378       err("tls_server");
    379   } else {
    380     if ((s.tls_ctx = tls_client()) == NULL)
    381       err("tls_client");
    382   }
    383   if ((config = tls_config_new()) == NULL)
    384     err("tls_config_new");
    385   if (tls_cert_file != NULL) {
    386     if (tls_config_set_cert_file(config, tls_cert_file) == -1)
    387       err("tls_config_set_cert_file");
    388   }
    389   if (tls_key_file != NULL) {
    390     if (tls_config_set_key_file(config, tls_key_file) == -1)
    391       err("tls_config_set_key_file");
    392   }
    393   if (tls_ca_file != NULL) {
    394     if (tls_config_set_ca_file(config, tls_ca_file) == -1)
    395       err("tls_config_set_ca_file");
    396   }
    397   if (tls_configure(s.tls_ctx, config) == -1)
    398     err("tls_configure");
    399 
    400   handle_termsignals(sigcleanup);
    401 
    402   // reap children
    403   act.sa_handler = SIG_IGN;
    404   sigemptyset(&act.sa_mask);
    405   act.sa_flags = 0;
    406   sigaction(SIGCHLD, &act, NULL);
    407 
    408   // setup proxy socket
    409   fd = proxy_udsfile
    410     ? sock_proxy_uds(proxy_udsfile)
    411     : sock_proxy_ips(proxy_host, proxy_port);
    412 
    413   addr_len = sizeof(addr);
    414   while (1) {
    415     if ((cfd = accept(fd, (struct sockaddr *) &addr, &addr_len)) == -1) {
    416       warn("accept");
    417       continue;
    418     }
    419 
    420     switch (fork()) {
    421     case 0:
    422       // signal handlers were copied with fork
    423       // restore default signal handlers
    424       // so that only parent process performs cleanup
    425       handle_termsignals(SIG_DFL);
    426 
    427       serve(&s, cfd, &addr);
    428       exit(0);
    429       break;
    430     case -1:
    431       warn("fork");
    432       close(cfd);
    433       break;
    434     default:
    435       close(cfd);
    436       break;
    437     }
    438   }
    439 
    440   // unreachable
    441   return 1;
    442 }