1 /*
2 * Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License").
5 * You may not use this file except in compliance with the License.
6 * A copy of the License is located at
7 *
8 * http://aws.amazon.com/apache2.0
9 *
10 * or in the "license" file accompanying this file. This file is distributed
11 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12 * express or implied. See the License for the specific language governing
13 * permissions and limitations under the License.
14 */
15
16 #include <sys/types.h>
17 #include <sys/socket.h>
18 #include <sys/ioctl.h>
19 #include <sys/poll.h>
20 #include <netdb.h>
21
22 #include <stdlib.h>
23 #include <unistd.h>
24 #include <string.h>
25 #include <stdio.h>
26
27 #include <errno.h>
28
29 #include <s2n.h>
30
31 int echo(struct s2n_connection *conn, int sockfd)
32 {
33 struct pollfd readers[2];
34
35 readers[0].fd = sockfd;
36 readers[0].events = POLLIN;
37 readers[1].fd = STDIN_FILENO;
38 readers[1].events = POLLIN;
39
40 int more;
41 do {
42 if (s2n_negotiate(conn, &more) < 0) {
43 fprintf(stderr, "Failed to negotiate: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
44 exit(1);
45 }
46 } while (more);
47
48 /* Now that we've negotiated, print some parameters */
49 int client_hello_version;
50 int client_protocol_version;
51 int server_protocol_version;
52 int actual_protocol_version;
53
54 if ((client_hello_version = s2n_connection_get_client_hello_version(conn)) < 0) {
55 fprintf(stderr, "Could not get client hello version\n");
56 exit(1);
57 }
58 if ((client_protocol_version = s2n_connection_get_client_protocol_version(conn)) < 0) {
59 fprintf(stderr, "Could not get client protocol version\n");
60 exit(1);
61 }
62 if ((server_protocol_version = s2n_connection_get_server_protocol_version(conn)) < 0) {
63 fprintf(stderr, "Could not get server protocol version\n");
64 exit(1);
65 }
66 if ((actual_protocol_version = s2n_connection_get_actual_protocol_version(conn)) < 0) {
67 fprintf(stderr, "Could not get actual protocol version\n");
68 exit(1);
69 }
70 printf("Client hello version: %d\n", client_hello_version);
71 printf("Client protocol version: %d\n", client_protocol_version);
72 printf("Server protocol version: %d\n", server_protocol_version);
73 printf("Actual protocol version: %d\n", actual_protocol_version);
74
75 if (s2n_get_server_name(conn)) {
76 printf("Server name: %s\n", s2n_get_server_name(conn));
77 }
78 if (s2n_get_application_protocol(conn)) {
79 printf("Application protocol: %s\n",
80 s2n_get_application_protocol(conn));
81 }
82 uint32_t length;
83 const uint8_t *status = s2n_connection_get_ocsp_response(conn, &length);
84 if (status && length > 0) {
85 fprintf(stderr, "OCSP response received, length %d\n", length);
86 }
87
88 printf("Cipher negotiated: %s\n", s2n_connection_get_cipher(conn));
89
90 /* Act as a simple proxy between stdin and the SSL connection */
91 while (poll(readers, 2, -1) > 0) {
92 char buffer[10240];
93 int bytes_read, bytes_written;
94
95 if (readers[0].revents & POLLIN) {
96 do {
97 bytes_read = s2n_recv(conn, buffer, 10240, &more);
98 if (bytes_read == 0) {
99 /* Connection has been closed */
100 s2n_connection_wipe(conn);
101 return 0;
102 }
103 if (bytes_read < 0) {
104 fprintf(stderr, "Error reading from connection: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
105 exit(1);
106 }
107 bytes_written = write(STDOUT_FILENO, buffer, bytes_read);
108 if (bytes_written <= 0) {
109 fprintf(stderr, "Error writing to stdout\n");
110 exit(1);
111 }
112 } while (more);
113 }
114 if (readers[1].revents & POLLIN) {
115 int bytes_available;
116 if (ioctl(STDIN_FILENO, FIONREAD, &bytes_available) < 0) {
117 bytes_available = 1;
118 }
119 if (bytes_available > sizeof(buffer)) {
120 bytes_available = sizeof(buffer);
121 }
122
123 /* Read as many bytes as we think we can */
124 bytes_read = read(STDIN_FILENO, buffer, bytes_available);
125 if (bytes_read < 0) {
126 fprintf(stderr, "Error reading from stdin\n");
127 exit(1);
128 }
129 if (bytes_read == 0) {
130 /* Exit on EOF */
131 return 0;
132 }
133
134 char *buf_ptr = buffer;
135 do {
136 bytes_written = s2n_send(conn, buf_ptr, bytes_available, &more);
137 if (bytes_written < 0) {
138 fprintf(stderr, "Error writing to connection: '%s'\n", s2n_strerror(s2n_errno, "EN"));
139 exit(1);
140 }
141
142 bytes_available -= bytes_written;
143 buf_ptr += bytes_written;
144 } while (bytes_available || more);
145 }
146 }
147
148 return 0;
149 }