1 /*
   2  * Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
   3  *
   4  * Licensed under the Apache License, Version 2.0 (the "License").
   5  * You may not use this file except in compliance with the License.
   6  * A copy of the License is located at
   7  *
   8  *  http://aws.amazon.com/apache2.0
   9  *
  10  * or in the "license" file accompanying this file. This file is distributed
  11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
  12  * express or implied. See the License for the specific language governing
  13  * permissions and limitations under the License.
  14  */
  15 
  16 #include <sys/types.h>
  17 #include <sys/socket.h>
  18 #include <sys/ioctl.h>
  19 #include <sys/filio.h>
  20 #include <sys/poll.h>
  21 #include <netdb.h>
  22 
  23 #include <stdlib.h>
  24 #include <unistd.h>
  25 #include <string.h>
  26 #include <stropts.h>
  27 #include <stdio.h>
  28 
  29 #include <errno.h>
  30 
  31 #include <s2n.h>
  32 
  33 int echo(struct s2n_connection *conn, int sockfd)
  34 {
  35     struct pollfd readers[2];
  36 
  37     readers[0].fd = sockfd;
  38     readers[0].events = POLLIN;
  39     readers[1].fd = STDIN_FILENO;
  40     readers[1].events = POLLIN;
  41 
  42     int more;
  43     do {
  44         if (s2n_negotiate(conn, &more) < 0) {
  45             fprintf(stderr, "Failed to negotiate: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
  46             exit(1);
  47         }
  48     } while (more);
  49 
  50     /* Now that we've negotiated, print some parameters */
  51     int client_hello_version;
  52     int client_protocol_version;
  53     int server_protocol_version;
  54     int actual_protocol_version;
  55 
  56     if ((client_hello_version = s2n_connection_get_client_hello_version(conn)) < 0) {
  57         fprintf(stderr, "Could not get client hello version\n");
  58         exit(1);
  59     }
  60     if ((client_protocol_version = s2n_connection_get_client_protocol_version(conn)) < 0) {
  61         fprintf(stderr, "Could not get client protocol version\n");
  62         exit(1);
  63     }
  64     if ((server_protocol_version = s2n_connection_get_server_protocol_version(conn)) < 0) {
  65         fprintf(stderr, "Could not get server protocol version\n");
  66         exit(1);
  67     }
  68     if ((actual_protocol_version = s2n_connection_get_actual_protocol_version(conn)) < 0) {
  69         fprintf(stderr, "Could not get actual protocol version\n");
  70         exit(1);
  71     }
  72     printf("Client hello version: %d\n", client_hello_version);
  73     printf("Client protocol version: %d\n", client_protocol_version);
  74     printf("Server protocol version: %d\n", server_protocol_version);
  75     printf("Actual protocol version: %d\n", actual_protocol_version);
  76 
  77     if (s2n_get_server_name(conn)) {
  78         printf("Server name: %s\n", s2n_get_server_name(conn));
  79     }
  80     if (s2n_get_application_protocol(conn)) {
  81         printf("Application protocol: %s\n",
  82                 s2n_get_application_protocol(conn));
  83     }
  84     uint32_t length;
  85     const uint8_t *status = s2n_connection_get_ocsp_response(conn, &length);
  86     if (status && length > 0) {
  87         fprintf(stderr, "OCSP response received, length %d\n", length);
  88     }
  89 
  90     printf("Cipher negotiated: %s\n", s2n_connection_get_cipher(conn));
  91 
  92     /* Act as a simple proxy between stdin and the SSL connection */
  93     while (poll(readers, 2, -1) > 0) {
  94         char buffer[10240];
  95         int bytes_read, bytes_written;
  96 
  97         if (readers[0].revents & POLLIN) {
  98             do {
  99                 bytes_read = s2n_recv(conn, buffer, 10240, &more);
 100                 if (bytes_read == 0) {
 101                     /* Connection has been closed */
 102                     s2n_connection_wipe(conn);
 103                     return 0;
 104                 }
 105                 if (bytes_read < 0) {
 106                     fprintf(stderr, "Error reading from connection: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
 107                     exit(1);
 108                 }
 109                 bytes_written = write(STDOUT_FILENO, buffer, bytes_read);
 110                 if (bytes_written <= 0) {
 111                     fprintf(stderr, "Error writing to stdout\n");
 112                     exit(1);
 113                 }
 114             } while (more);
 115         }
 116         if (readers[1].revents & POLLIN) {
 117             int bytes_available;
 118             if (ioctl(STDIN_FILENO, FIONREAD, &bytes_available) < 0) {
 119                 bytes_available = 1;
 120             }
 121             if (bytes_available > sizeof(buffer)) {
 122                 bytes_available = sizeof(buffer);
 123             }
 124 
 125             /* Read as many bytes as we think we can */
 126             bytes_read = read(STDIN_FILENO, buffer, bytes_available);
 127             if (bytes_read < 0) {
 128                 fprintf(stderr, "Error reading from stdin\n");
 129                 exit(1);
 130             }
 131             if (bytes_read == 0) {
 132                 /* Exit on EOF */
 133                 return 0;
 134             }
 135 
 136             char *buf_ptr = buffer;
 137             do {
 138                 bytes_written = s2n_send(conn, buf_ptr, bytes_available, &more);
 139                 if (bytes_written < 0) {
 140                     fprintf(stderr, "Error writing to connection: '%s'\n", s2n_strerror(s2n_errno, "EN"));
 141                     exit(1);
 142                 }
 143 
 144                 bytes_available -= bytes_written;
 145                 buf_ptr += bytes_written;
 146             } while (bytes_available || more);
 147         }
 148     }
 149 
 150     return 0;
 151 }