commit 2a9aa77706f96e965894bd184d29ad6eae6b69ad
parent d154c1068b6e353e2e542b92a03d2e753fe7b27d
Author: Nihal Jere <nihal@nihaljere.xyz>
Date: Sun, 3 May 2020 20:05:07 -0500
use TLS_POLL(IN/OUT)
Diffstat:
M | Makefile | | | 2 | +- |
M | tlsrp.c | | | 59 | +++++++++++++++++++++++++++++------------------------------ |
2 files changed, 30 insertions(+), 31 deletions(-)
diff --git a/Makefile b/Makefile
@@ -16,4 +16,4 @@ clean:
rm $(OBJ)
run:
- LD_LIBRARY_PATH=/usr/lib/libressl ./$(OBJ) -U "/tmp/conn.socket" -f 443
+ LD_LIBRARY_PATH=/usr/lib/libressl ./$(OBJ) -U "/tmp/conn.socket" -f 443 -a "/home/nihal/projects/libtls/CA/root.pem" -r "/home/nihal/projects/libtls/CA/server.crt" -k "/home/nihal/projects/libtls/CA/server.key"
diff --git a/tlsrp.c b/tlsrp.c
@@ -117,29 +117,6 @@ donetworkconnect(const char* host, const char* port)
return sfd;
}
-static void dowrite(int fd, char* buf, size_t towrite) {
- ssize_t written = 0;
- while (towrite > 0) {
- written = write(fd, buf, towrite);
- if (written == -1)
- die("failed to write:");
- towrite -= written;
- buf += written;
- }
-}
-
-static void dotlswrite(struct tls *tlss, char* buf, size_t towrite) {
- ssize_t written = 0;
- while (towrite > 0) {
- written = tls_write(tlss, buf, towrite);
- if (written == -1)
- die("failed to write:");
- towrite -= written;
- buf += written;
- }
-}
-
-// TODO use TLS_WANT_POLL(IN/OUT) instead of normal ones
static int
serve(int serverfd, int clientfd, struct tls *clientconn)
{
@@ -151,36 +128,58 @@ serve(int serverfd, int clientfd, struct tls *clientconn)
char clibuf[BUF_SIZE] = {0};
char serbuf[BUF_SIZE] = {0};
+ char *cliptr = NULL, *serptr = NULL;
+
size_t clicount = 0, sercount = 0;
+ ssize_t written = 0;
while (poll(pfd, 2, TIMEOUT) != 0) {
if ((pfd[CLIENT].revents | pfd[SERVER].revents) & POLLNVAL)
return -1;
- if ((pfd[CLIENT].revents & POLLIN)) {
+ if ((pfd[CLIENT].revents & POLLIN) && clicount == 0) {
clicount = tls_read(clientconn, clibuf, BUF_SIZE);
if (clicount == -1) {
die("client read failed: %s\n", tls_error(clientconn));
return -2;
+ } else if (clicount == TLS_WANT_POLLIN) {
+ pfd[CLIENT].events = POLLIN;
+ } else if (clicount == TLS_WANT_POLLOUT) {
+ pfd[CLIENT].events = POLLOUT;
+ } else {
+ cliptr = clibuf;
}
}
- if ((pfd[SERVER].revents & POLLIN)) {
+ if ((pfd[SERVER].revents & POLLIN) && sercount == 0) {
sercount = read(serverfd, serbuf, BUF_SIZE);
if (sercount == -1) {
die("server read failed:");
return -3;
}
+ serptr = serbuf;
}
if ((pfd[SERVER].revents & POLLOUT) && clicount > 0) {
- dowrite(serverfd, clibuf, clicount);
- clicount = 0;
+ written = write(serverfd, cliptr, clicount);
+ if (written == -1)
+ die("failed to write:");
+ clicount -= written;
+ cliptr += written;
}
if ((pfd[CLIENT].revents & POLLOUT) && sercount > 0) {
- dotlswrite(clientconn, serbuf, sercount);
- sercount = 0;
+ written = tls_write(clientconn, serptr, sercount);
+ if (written == -1)
+ die("failed tls_write: %s\n", tls_error(clientconn));
+ else if (written == TLS_WANT_POLLIN) {
+ pfd[CLIENT].events = POLLIN;
+ } else if (written == TLS_WANT_POLLOUT) {
+ pfd[CLIENT].events = POLLOUT;
+ } else {
+ sercount -= written;
+ serptr += written;
+ }
}
if ((pfd[CLIENT].revents | pfd[SERVER].revents) & POLLHUP)
@@ -309,7 +308,7 @@ main(int argc, char* argv[])
serve(serverfd, clientfd, conn);
tls_close(conn);
-tlsfail:
+ tlsfail:
close(serverfd);
close(clientfd);
close(bindfd);