1 /*
   2  * Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
   3  *
   4  * Licensed under the Apache License, Version 2.0 (the "License").
   5  * You may not use this file except in compliance with the License.
   6  * A copy of the License is located at
   7  *
   8  *  http://aws.amazon.com/apache2.0
   9  *
  10  * or in the "license" file accompanying this file. This file is distributed
  11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
  12  * express or implied. See the License for the specific language governing
  13  * permissions and limitations under the License.
  14  */
  15 
  16 #include <sys/types.h>
  17 #include <sys/socket.h>
  18 #include <sys/ioctl.h>
  19 #include <sys/poll.h>
  20 #include <netdb.h>
  21 
  22 #include <stdlib.h>
  23 #include <unistd.h>
  24 #include <string.h>
  25 #include <stdio.h>
  26 
  27 #include <errno.h>
  28 
  29 #include <s2n.h>
  30 
  31 int echo(struct s2n_connection *conn, int sockfd)
  32 {
  33     struct pollfd readers[2];
  34 
  35     readers[0].fd = sockfd;
  36     readers[0].events = POLLIN;
  37     readers[1].fd = STDIN_FILENO;
  38     readers[1].events = POLLIN;
  39 
  40     int more;
  41     do {
  42         if (s2n_negotiate(conn, &more) < 0) {
  43             fprintf(stderr, "Failed to negotiate: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
  44             exit(1);
  45         }
  46     } while (more);
  47 
  48     /* Now that we've negotiated, print some parameters */
  49     int client_hello_version;
  50     int client_protocol_version;
  51     int server_protocol_version;
  52     int actual_protocol_version;
  53 
  54     if ((client_hello_version = s2n_connection_get_client_hello_version(conn)) < 0) {
  55         fprintf(stderr, "Could not get client hello version\n");
  56         exit(1);
  57     }
  58     if ((client_protocol_version = s2n_connection_get_client_protocol_version(conn)) < 0) {
  59         fprintf(stderr, "Could not get client protocol version\n");
  60         exit(1);
  61     }
  62     if ((server_protocol_version = s2n_connection_get_server_protocol_version(conn)) < 0) {
  63         fprintf(stderr, "Could not get server protocol version\n");
  64         exit(1);
  65     }
  66     if ((actual_protocol_version = s2n_connection_get_actual_protocol_version(conn)) < 0) {
  67         fprintf(stderr, "Could not get actual protocol version\n");
  68         exit(1);
  69     }
  70     printf("Client hello version: %d\n", client_hello_version);
  71     printf("Client protocol version: %d\n", client_protocol_version);
  72     printf("Server protocol version: %d\n", server_protocol_version);
  73     printf("Actual protocol version: %d\n", actual_protocol_version);
  74 
  75     if (s2n_get_server_name(conn)) {
  76         printf("Server name: %s\n", s2n_get_server_name(conn));
  77     }
  78     if (s2n_get_application_protocol(conn)) {
  79         printf("Application protocol: %s\n",
  80                 s2n_get_application_protocol(conn));
  81     }
  82     uint32_t length;
  83     const uint8_t *status = s2n_connection_get_ocsp_response(conn, &length);
  84     if (status && length > 0) {
  85         fprintf(stderr, "OCSP response received, length %d\n", length);
  86     }
  87 
  88     printf("Cipher negotiated: %s\n", s2n_connection_get_cipher(conn));
  89 
  90     /* Act as a simple proxy between stdin and the SSL connection */
  91     while (poll(readers, 2, -1) > 0) {
  92         char buffer[10240];
  93         int bytes_read, bytes_written;
  94 
  95         if (readers[0].revents & POLLIN) {
  96             do {
  97                 bytes_read = s2n_recv(conn, buffer, 10240, &more);
  98                 if (bytes_read == 0) {
  99                     /* Connection has been closed */
 100                     s2n_connection_wipe(conn);
 101                     return 0;
 102                 }
 103                 if (bytes_read < 0) {
 104                     fprintf(stderr, "Error reading from connection: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
 105                     exit(1);
 106                 }
 107                 bytes_written = write(STDOUT_FILENO, buffer, bytes_read);
 108                 if (bytes_written <= 0) {
 109                     fprintf(stderr, "Error writing to stdout\n");
 110                     exit(1);
 111                 }
 112             } while (more);
 113         }
 114         if (readers[1].revents & POLLIN) {
 115             int bytes_available;
 116             if (ioctl(STDIN_FILENO, FIONREAD, &bytes_available) < 0) {
 117                 bytes_available = 1;
 118             }
 119             if (bytes_available > sizeof(buffer)) {
 120                 bytes_available = sizeof(buffer);
 121             }
 122 
 123             /* Read as many bytes as we think we can */
 124             bytes_read = read(STDIN_FILENO, buffer, bytes_available);
 125             if (bytes_read < 0) {
 126                 fprintf(stderr, "Error reading from stdin\n");
 127                 exit(1);
 128             }
 129             if (bytes_read == 0) {
 130                 /* Exit on EOF */
 131                 return 0;
 132             }
 133 
 134             char *buf_ptr = buffer;
 135             do {
 136                 bytes_written = s2n_send(conn, buf_ptr, bytes_available, &more);
 137                 if (bytes_written < 0) {
 138                     fprintf(stderr, "Error writing to connection: '%s'\n", s2n_strerror(s2n_errno, "EN"));
 139                     exit(1);
 140                 }
 141 
 142                 bytes_available -= bytes_written;
 143                 buf_ptr += bytes_written;
 144             } while (bytes_available || more);
 145         }
 146     }
 147 
 148     return 0;
 149 }