1 /*
2 * Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License").
5 * You may not use this file except in compliance with the License.
6 * A copy of the License is located at
7 *
8 * http://aws.amazon.com/apache2.0
9 *
10 * or in the "license" file accompanying this file. This file is distributed
11 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12 * express or implied. See the License for the specific language governing
13 * permissions and limitations under the License.
14 */
15
16 #include <sys/types.h>
17 #include <sys/socket.h>
18 #include <sys/ioctl.h>
19 #include <sys/filio.h>
20 #include <sys/poll.h>
21 #include <netdb.h>
22
23 #include <stdlib.h>
24 #include <unistd.h>
25 #include <string.h>
26 #include <stropts.h>
27 #include <stdio.h>
28
29 #include <errno.h>
30
31 #include <s2n.h>
32
33 int echo(struct s2n_connection *conn, int sockfd)
34 {
35 struct pollfd readers[2];
36
37 readers[0].fd = sockfd;
38 readers[0].events = POLLIN;
39 readers[1].fd = STDIN_FILENO;
40 readers[1].events = POLLIN;
41
42 int more;
43 do {
44 if (s2n_negotiate(conn, &more) < 0) {
45 fprintf(stderr, "Failed to negotiate: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
46 exit(1);
47 }
48 } while (more);
49
50 /* Now that we've negotiated, print some parameters */
51 int client_hello_version;
52 int client_protocol_version;
53 int server_protocol_version;
54 int actual_protocol_version;
55
56 if ((client_hello_version = s2n_connection_get_client_hello_version(conn)) < 0) {
57 fprintf(stderr, "Could not get client hello version\n");
58 exit(1);
59 }
60 if ((client_protocol_version = s2n_connection_get_client_protocol_version(conn)) < 0) {
61 fprintf(stderr, "Could not get client protocol version\n");
62 exit(1);
63 }
64 if ((server_protocol_version = s2n_connection_get_server_protocol_version(conn)) < 0) {
65 fprintf(stderr, "Could not get server protocol version\n");
66 exit(1);
67 }
68 if ((actual_protocol_version = s2n_connection_get_actual_protocol_version(conn)) < 0) {
69 fprintf(stderr, "Could not get actual protocol version\n");
70 exit(1);
71 }
72 printf("Client hello version: %d\n", client_hello_version);
73 printf("Client protocol version: %d\n", client_protocol_version);
74 printf("Server protocol version: %d\n", server_protocol_version);
75 printf("Actual protocol version: %d\n", actual_protocol_version);
76
77 if (s2n_get_server_name(conn)) {
78 printf("Server name: %s\n", s2n_get_server_name(conn));
79 }
80 if (s2n_get_application_protocol(conn)) {
81 printf("Application protocol: %s\n",
82 s2n_get_application_protocol(conn));
83 }
84 uint32_t length;
85 const uint8_t *status = s2n_connection_get_ocsp_response(conn, &length);
86 if (status && length > 0) {
87 fprintf(stderr, "OCSP response received, length %d\n", length);
88 }
89
90 printf("Cipher negotiated: %s\n", s2n_connection_get_cipher(conn));
91
92 /* Act as a simple proxy between stdin and the SSL connection */
93 while (poll(readers, 2, -1) > 0) {
94 char buffer[10240];
95 int bytes_read, bytes_written;
96
97 if (readers[0].revents & POLLIN) {
98 do {
99 bytes_read = s2n_recv(conn, buffer, 10240, &more);
100 if (bytes_read == 0) {
101 /* Connection has been closed */
102 s2n_connection_wipe(conn);
103 return 0;
104 }
105 if (bytes_read < 0) {
106 fprintf(stderr, "Error reading from connection: '%s' %d\n", s2n_strerror(s2n_errno, "EN"), s2n_connection_get_alert(conn));
107 exit(1);
108 }
109 bytes_written = write(STDOUT_FILENO, buffer, bytes_read);
110 if (bytes_written <= 0) {
111 fprintf(stderr, "Error writing to stdout\n");
112 exit(1);
113 }
114 } while (more);
115 }
116 if (readers[1].revents & POLLIN) {
117 int bytes_available;
118 if (ioctl(STDIN_FILENO, FIONREAD, &bytes_available) < 0) {
119 bytes_available = 1;
120 }
121 if (bytes_available > sizeof(buffer)) {
122 bytes_available = sizeof(buffer);
123 }
124
125 /* Read as many bytes as we think we can */
126 bytes_read = read(STDIN_FILENO, buffer, bytes_available);
127 if (bytes_read < 0) {
128 fprintf(stderr, "Error reading from stdin\n");
129 exit(1);
130 }
131 if (bytes_read == 0) {
132 /* Exit on EOF */
133 return 0;
134 }
135
136 char *buf_ptr = buffer;
137 do {
138 bytes_written = s2n_send(conn, buf_ptr, bytes_available, &more);
139 if (bytes_written < 0) {
140 fprintf(stderr, "Error writing to connection: '%s'\n", s2n_strerror(s2n_errno, "EN"));
141 exit(1);
142 }
143
144 bytes_available -= bytes_written;
145 buf_ptr += bytes_written;
146 } while (bytes_available || more);
147 }
148 }
149
150 return 0;
151 }