Botan 3.0.0-alpha0
Crypto and TLS for C++
tls_handshake_io.cpp
Go to the documentation of this file.
1/*
2* TLS Handshake IO
3* (C) 2012,2014,2015 Jack Lloyd
4*
5* Botan is released under the Simplified BSD License (see license.txt)
6*/
7
8#include <botan/internal/tls_handshake_io.h>
9#include <botan/internal/tls_record.h>
10#include <botan/internal/tls_seq_numbers.h>
11#include <botan/tls_messages.h>
12#include <botan/exceptn.h>
13#include <botan/internal/loadstor.h>
14#include <chrono>
15
16namespace Botan::TLS {
17
namespace {

/*
* Decode three big-endian bytes into an integer. Used in this file for
* DTLS handshake message lengths and fragment offsets/lengths.
*/
inline size_t load_be24(const uint8_t q[3])
   {
   const uint32_t high = q[0];
   const uint32_t middle = q[1];
   const uint32_t low = q[2];
   return static_cast<size_t>((high << 16) | (middle << 8) | low);
   }

/*
* Write the low 24 bits of val as three big-endian bytes.
* Any higher-order bits are discarded without a check.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   const uint32_t v = static_cast<uint32_t>(val);
   out[0] = static_cast<uint8_t>((v >> 16) & 0xFF);
   out[1] = static_cast<uint8_t>((v >> 8) & 0xFF);
   out[2] = static_cast<uint8_t>(v & 0xFF);
   }

/*
* Current reading of the monotonic steady clock in milliseconds. The
* clock's epoch is unspecified, so only differences between readings
* are meaningful.
*/
uint64_t steady_clock_ms()
   {
   using std::chrono::duration_cast;
   using std::chrono::milliseconds;

   const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
   return duration_cast<milliseconds>(since_epoch).count();
   }

}
42
44 {
46 }
47
48void Stream_Handshake_IO::add_record(const uint8_t record[],
49 size_t record_len,
50 Record_Type record_type, uint64_t /*sequence_number*/)
51 {
52 if(record_type == HANDSHAKE)
53 {
54 m_queue.insert(m_queue.end(), record, record + record_len);
55 }
56 else if(record_type == CHANGE_CIPHER_SPEC)
57 {
58 if(record_len != 1 || record[0] != 1)
59 throw Decoding_Error("Invalid ChangeCipherSpec");
60
61 // Pretend it's a regular handshake message of zero length
62 const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
63 m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
64 }
65 else
66 throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
67 }
68
69std::pair<Handshake_Type, std::vector<uint8_t>>
71 {
72 if(m_queue.size() >= 4)
73 {
74 const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
75
76 if(m_queue.size() >= length)
77 {
78 Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
79
80 if(type == HANDSHAKE_NONE)
81 throw Decoding_Error("Invalid handshake message type");
82
83 std::vector<uint8_t> contents(m_queue.begin() + 4,
84 m_queue.begin() + length);
85
86 m_queue.erase(m_queue.begin(), m_queue.begin() + length);
87
88 return std::make_pair(type, contents);
89 }
90 }
91
92 return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
93 }
94
95std::vector<uint8_t>
96Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
97 Handshake_Type type) const
98 {
99 std::vector<uint8_t> send_buf(4 + msg.size());
100
101 const size_t buf_size = msg.size();
102
103 send_buf[0] = static_cast<uint8_t>(type);
104
105 store_be24(&send_buf[1], buf_size);
106
107 if (!msg.empty())
108 {
109 copy_mem(&send_buf[4], msg.data(), msg.size());
110 }
111
112 return send_buf;
113 }
114
115std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
116 {
117 throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
118 }
119
120std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
121 {
122 const std::vector<uint8_t> msg_bits = msg.serialize();
123
124 if(msg.type() == HANDSHAKE_CCS)
125 {
126 m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
127 return std::vector<uint8_t>(); // not included in handshake hashes
128 }
129
130 auto buf = format(msg_bits, msg.wire_type());
131 m_send_hs(HANDSHAKE, buf);
132 return buf;
133 }
134
136 {
138 }
139
140void Datagram_Handshake_IO::retransmit_last_flight()
141 {
142 const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
143 retransmit_flight(flight_idx);
144 }
145
146void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
147 {
148 const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
149
150 BOTAN_ASSERT(!flight.empty(), "Nonempty flight to retransmit");
151
152 uint16_t epoch = m_flight_data[flight[0]].epoch;
153
154 for(auto msg_seq : flight)
155 {
156 auto& msg = m_flight_data[msg_seq];
157
158 if(msg.epoch != epoch)
159 {
160 // Epoch gap: insert the CCS
161 std::vector<uint8_t> ccs(1, 1);
162 m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
163 }
164
165 send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
166 epoch = msg.epoch;
167 }
168 }
169
171 {
172 return false;
173 }
174
176 {
177 if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
178 {
179 /*
180 If we haven't written anything yet obviously no timeout.
181 Also no timeout possible if we are mid-flight,
182 */
183 return false;
184 }
185
186 const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
187
188 if(ms_since_write < m_next_timeout)
189 return false;
190
191 retransmit_last_flight();
192
193 m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
194 return true;
195 }
196
197void Datagram_Handshake_IO::add_record(const uint8_t record[],
198 size_t record_len,
199 Record_Type record_type,
200 uint64_t record_sequence)
201 {
202 const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
203
204 if(record_type == CHANGE_CIPHER_SPEC)
205 {
206 if(record_len != 1 || record[0] != 1)
207 throw Decoding_Error("Invalid ChangeCipherSpec");
208
209 // TODO: check this is otherwise empty
210 m_ccs_epochs.insert(epoch);
211 return;
212 }
213
214 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
215
216 while(record_len)
217 {
218 if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
219 return; // completely bogus? at least degenerate/weird
220
221 const uint8_t msg_type = record[0];
222 const size_t msg_len = load_be24(&record[1]);
223 const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
224 const size_t fragment_offset = load_be24(&record[6]);
225 const size_t fragment_length = load_be24(&record[9]);
226
227 const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
228
229 if(record_len < total_size)
230 throw Decoding_Error("Bad lengths in DTLS header");
231
232 if(message_seq >= m_in_message_seq)
233 {
234 m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
236 fragment_offset,
237 epoch,
238 msg_type,
239 msg_len);
240 }
241 else
242 {
243 // TODO: detect retransmitted flight
244 }
245
246 record += total_size;
247 record_len -= total_size;
248 }
249 }
250
251std::pair<Handshake_Type, std::vector<uint8_t>>
253 {
254 // Expecting a message means the last flight is concluded
255 if(!m_flights.rbegin()->empty())
256 m_flights.push_back(std::vector<uint16_t>());
257
258 if(expecting_ccs)
259 {
260 if(!m_messages.empty())
261 {
262 const uint16_t current_epoch = m_messages.begin()->second.epoch();
263
264 if(m_ccs_epochs.count(current_epoch))
265 return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
266 }
267 return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
268 }
269
270 auto i = m_messages.find(m_in_message_seq);
271
272 if(i == m_messages.end() || !i->second.complete())
273 {
274 return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
275 }
276
277 m_in_message_seq += 1;
278
279 return i->second.message();
280 }
281
282void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
283 const uint8_t fragment[],
284 size_t fragment_length,
285 size_t fragment_offset,
286 uint16_t epoch,
287 uint8_t msg_type,
288 size_t msg_length)
289 {
290 if(complete())
291 return; // already have entire message, ignore this
292
293 if(m_msg_type == HANDSHAKE_NONE)
294 {
295 m_epoch = epoch;
296 m_msg_type = msg_type;
297 m_msg_length = msg_length;
298 }
299
300 if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
301 throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
302
303 if(fragment_offset > m_msg_length)
304 throw Decoding_Error("Fragment offset past end of message");
305
306 if(fragment_offset + fragment_length > m_msg_length)
307 throw Decoding_Error("Fragment overlaps past end of message");
308
309 if(fragment_offset == 0 && fragment_length == m_msg_length)
310 {
311 m_fragments.clear();
312 m_message.assign(fragment, fragment+fragment_length);
313 }
314 else
315 {
316 /*
317 * FIXME. This is a pretty lame way to do defragmentation, huge
318 * overhead with a tree node per byte.
319 *
320 * Also should confirm that all overlaps have no changes,
321 * otherwise we expose ourselves to the classic fingerprinting
322 * and IDS evasion attacks on IP fragmentation.
323 */
324 for(size_t i = 0; i != fragment_length; ++i)
325 m_fragments[fragment_offset+i] = fragment[i];
326
327 if(m_fragments.size() == m_msg_length)
328 {
329 m_message.resize(m_msg_length);
330 for(size_t i = 0; i != m_msg_length; ++i)
331 m_message[i] = m_fragments[i];
332 m_fragments.clear();
333 }
334 }
335 }
336
337bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
338 {
339 return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
340 }
341
342std::pair<Handshake_Type, std::vector<uint8_t>>
343Datagram_Handshake_IO::Handshake_Reassembly::message() const
344 {
345 if(!complete())
346 throw Internal_Error("Datagram_Handshake_IO - message not complete");
347
348 return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
349 }
350
351std::vector<uint8_t>
352Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
353 size_t frag_len,
354 uint16_t frag_offset,
355 uint16_t msg_len,
357 uint16_t msg_sequence) const
358 {
359 std::vector<uint8_t> send_buf(12 + frag_len);
360
361 send_buf[0] = static_cast<uint8_t>(type);
362
363 store_be24(&send_buf[1], msg_len);
364
365 store_be(msg_sequence, &send_buf[4]);
366
367 store_be24(&send_buf[6], frag_offset);
368 store_be24(&send_buf[9], frag_len);
369
370 if (frag_len > 0)
371 {
372 copy_mem(&send_buf[12], fragment, frag_len);
373 }
374
375 return send_buf;
376 }
377
378std::vector<uint8_t>
379Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
381 uint16_t msg_sequence) const
382 {
383 return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
384 }
385
386std::vector<uint8_t>
387Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
388 Handshake_Type type) const
389 {
390 return format_w_seq(msg, type, m_in_message_seq - 1);
391 }
392
393std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
394 {
395 return this->send_under_epoch(msg, m_seqs.current_write_epoch());
396 }
397
398std::vector<uint8_t>
400 {
401 const std::vector<uint8_t> msg_bits = msg.serialize();
402 const Handshake_Type msg_type = msg.type();
403
404 if(msg_type == HANDSHAKE_CCS)
405 {
406 m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
407 return std::vector<uint8_t>(); // not included in handshake hashes
408 }
409 else if(msg_type == HELLO_VERIFY_REQUEST)
410 {
411 // This message is not included in the handshake hashes
412 send_message(m_out_message_seq, epoch, msg_type, msg_bits);
413 m_out_message_seq += 1;
414 return std::vector<uint8_t>();
415 }
416
417 // Note: not saving CCS, instead we know it was there due to change in epoch
418 m_flights.rbegin()->push_back(m_out_message_seq);
419 m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
420
421 m_out_message_seq += 1;
422 m_last_write = steady_clock_ms();
423 m_next_timeout = m_initial_timeout;
424
425 return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
426 }
427
428std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
429 uint16_t epoch,
430 Handshake_Type msg_type,
431 const std::vector<uint8_t>& msg_bits)
432 {
433 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
434
435 auto no_fragment = format_w_seq(msg_bits, msg_type, msg_seq);
436
437 if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
438 {
439 m_send_hs(epoch, HANDSHAKE, no_fragment);
440 }
441 else
442 {
443 size_t frag_offset = 0;
444
445 /**
446 * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
447 * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
448 * over-estimate here. When CBC ciphers are removed this can be reduced
449 * since AEAD modes have no padding, at most 16 byte mac, and smaller
450 * per-record nonce.
451 */
452 const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
453 const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
454
455 if(m_mtu <= (header_overhead + ciphersuite_overhead))
456 throw Invalid_Argument("DTLS MTU is too small to send headers");
457
458 const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
459
460 while(frag_offset != msg_bits.size())
461 {
462 const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
463
464 const std::vector<uint8_t> frag =
465 format_fragment(&msg_bits[frag_offset],
466 frag_len,
467 static_cast<uint16_t>(frag_offset),
468 static_cast<uint16_t>(msg_bits.size()),
469 msg_type,
470 msg_seq);
471
472 m_send_hs(epoch, HANDSHAKE, frag);
473
474 frag_offset += frag_len;
475 }
476 }
477
478 return no_fragment;
479 }
480
481}
#define BOTAN_ASSERT(expr, assertion_made)
Definition: assert.h:54
virtual uint16_t current_write_epoch() const =0
std::vector< uint8_t > send_under_epoch(const Handshake_Message &msg, uint16_t epoch) override
void add_record(const uint8_t record[], size_t record_len, Record_Type type, uint64_t sequence_number) override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
Protocol_Version initial_record_version() const override
std::vector< uint8_t > send(const Handshake_Message &msg) override
virtual Handshake_Type type() const =0
virtual std::vector< uint8_t > serialize() const =0
virtual Handshake_Type wire_type() const
std::vector< uint8_t > send_under_epoch(const Handshake_Message &msg, uint16_t epoch) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
Protocol_Version initial_record_version() const override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::vector< uint8_t > send(const Handshake_Message &msg) override
void add_record(const uint8_t record[], size_t record_len, Record_Type type, uint64_t sequence_number) override
std::string to_string(const BER_Object &obj)
Definition: asn1_obj.cpp:209
@ HELLO_VERIFY_REQUEST
Definition: tls_magic.h:67
@ HANDSHAKE_CCS
Definition: tls_magic.h:87
@ HANDSHAKE_NONE
Definition: tls_magic.h:88
@ CHANGE_CIPHER_SPEC
Definition: tls_magic.h:52
@ DTLS_HEADER_SIZE
Definition: tls_magic.h:28
constexpr void copy_mem(T *out, const T *in, size_t n)
Definition: mem_ops.h:126
constexpr uint32_t make_uint32(uint8_t i0, uint8_t i1, uint8_t i2, uint8_t i3)
Definition: loadstor.h:78
constexpr uint16_t load_be< uint16_t >(const uint8_t in[], size_t off)
Definition: loadstor.h:150
constexpr void store_be(uint16_t in, uint8_t out[2])
Definition: loadstor.h:449
MechanismType type
uint16_t fragment_length