Botan 3.11.1
Crypto and TLS for C++
tls_handshake_io.cpp
Go to the documentation of this file.
1/*
2* TLS Handshake IO
3* (C) 2012,2014,2015 Jack Lloyd
4*
5* Botan is released under the Simplified BSD License (see license.txt)
6*/
7
8#include <botan/internal/tls_handshake_io.h>
9
10#include <botan/exceptn.h>
11#include <botan/tls_exceptn.h>
12#include <botan/tls_handshake_msg.h>
13#include <botan/internal/loadstor.h>
14#include <botan/internal/tls_record.h>
15#include <botan/internal/tls_seq_numbers.h>
16#include <chrono>
17
18namespace Botan::TLS {
19
namespace {

/*
* Decode a 24-bit big-endian unsigned integer stored in q[0..2]
* (most significant byte first).
*/
inline size_t load_be24(const uint8_t q[3]) {
   const size_t high = q[0];
   const size_t middle = q[1];
   const size_t low = q[2];
   return (high << 16) | (middle << 8) | low;
}

/*
* Encode the low 24 bits of val into out[0..2], most significant byte first.
* Any bits above bit 23 are discarded; callers must ensure val fits.
*/
void store_be24(uint8_t out[3], size_t val) {
   const uint32_t v32 = static_cast<uint32_t>(val);
   out[0] = static_cast<uint8_t>(v32 >> 16);
   out[1] = static_cast<uint8_t>(v32 >> 8);
   out[2] = static_cast<uint8_t>(v32);
}

/*
* Current reading of the monotonic steady clock, in whole milliseconds since
* the clock's (unspecified) epoch. Only differences between readings are meaningful.
*/
uint64_t steady_clock_ms() {
   const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
   const auto as_ms = std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch);
   return as_ms.count();
}

}  // namespace
38
40 return Protocol_Version::TLS_V12;
41}
42
43void Stream_Handshake_IO::add_record(const uint8_t record[],
44 size_t record_len,
45 Record_Type record_type,
46 uint64_t /*sequence_number*/) {
47 if(record_type == Record_Type::Handshake) {
48 m_queue.insert(m_queue.end(), record, record + record_len);
49 } else if(record_type == Record_Type::ChangeCipherSpec) {
50 if(record_len != 1 || record[0] != 1) {
51 throw Decoding_Error("Invalid ChangeCipherSpec");
52 }
53
54 // Pretend it's a regular handshake message of zero length
55 const uint8_t ccs_hs[] = {static_cast<uint8_t>(Handshake_Type::HandshakeCCS), 0, 0, 0};
56 m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
57 } else {
58 throw Decoding_Error("Unknown message type " + std::to_string(static_cast<size_t>(record_type)) +
59 " in handshake processing");
60 }
61}
62
63std::pair<Handshake_Type, std::vector<uint8_t>> Stream_Handshake_IO::get_next_record(bool expecting_ccs) {
64 if(m_queue.size() >= 4) {
65 const Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
66
67 if(type == Handshake_Type::None) {
68 throw Decoding_Error("Invalid handshake message type");
69 }
70
71 const size_t rec_length = make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
72
73 // If we are expecting a CCS but the next queued message is not a CCS,
74 // the peer has skipped the CCS message. This can happen when the peer
75 // sends an encrypted Finished without the preceding CCS, in which case
76 // the encrypted bytes are misinterpreted as a handshake message.
77 if(expecting_ccs) {
78 const bool is_ccs = (type == Handshake_Type::HandshakeCCS && rec_length == 0);
79 if(!is_ccs) {
80 throw TLS_Exception(Alert::UnexpectedMessage, "Expected ChangeCipherSpec but got a handshake message");
81 }
82 }
83
84 const size_t length = 4 + rec_length;
85
86 if(m_queue.size() >= length) {
87 const std::vector<uint8_t> contents(m_queue.begin() + 4, m_queue.begin() + length);
88
89 m_queue.erase(m_queue.begin(), m_queue.begin() + length);
90
91 return std::make_pair(type, contents);
92 }
93 }
94
95 return std::make_pair(Handshake_Type::None, std::vector<uint8_t>());
96}
97
98std::vector<uint8_t> Stream_Handshake_IO::format(const std::vector<uint8_t>& msg, Handshake_Type type) const {
99 std::vector<uint8_t> send_buf(4 + msg.size());
100
101 const size_t buf_size = msg.size();
102
103 send_buf[0] = static_cast<uint8_t>(type);
104
105 store_be24(&send_buf[1], buf_size);
106
107 if(!msg.empty()) {
108 copy_mem(&send_buf[4], msg.data(), msg.size());
109 }
110
111 return send_buf;
112}
113
114std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/) {
115 throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
116}
117
118std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg) {
119 const std::vector<uint8_t> msg_bits = msg.serialize();
120
122 m_send_hs(Record_Type::ChangeCipherSpec, msg_bits);
123 return std::vector<uint8_t>(); // not included in handshake hashes
124 }
125
126 auto buf = format(msg_bits, msg.wire_type());
127 m_send_hs(Record_Type::Handshake, buf);
128 return buf;
129}
130
132 return Protocol_Version::DTLS_V12;
133}
134
135void Datagram_Handshake_IO::retransmit_last_flight() {
136 const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
137 retransmit_flight(flight_idx);
138}
139
140void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx) {
141 const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
142
143 BOTAN_ASSERT(!flight.empty(), "Nonempty flight to retransmit");
144
145 uint16_t epoch = m_flight_data[flight[0]].epoch;
146
147 for(auto msg_seq : flight) {
148 auto& msg = m_flight_data[msg_seq];
149
150 if(msg.epoch != epoch) {
151 // Epoch gap: insert the CCS
152 const std::vector<uint8_t> ccs(1, 1);
153 m_send_hs(epoch, Record_Type::ChangeCipherSpec, ccs);
154 }
155
156 send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
157 epoch = msg.epoch;
158 }
159}
160
162 return false;
163}
164
166 if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) {
167 /*
168 If we haven't written anything yet obviously no timeout.
169 Also no timeout possible if we are mid-flight,
170 */
171 return false;
172 }
173
174 const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
175
176 if(ms_since_write < m_next_timeout) {
177 return false;
178 }
179
180 retransmit_last_flight();
181
182 m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
183 return true;
184}
185
186void Datagram_Handshake_IO::add_record(const uint8_t record[],
187 size_t record_len,
188 Record_Type record_type,
189 uint64_t record_sequence) {
190 const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
191
192 if(record_type == Record_Type::ChangeCipherSpec) {
193 if(record_len != 1 || record[0] != 1) {
194 throw Decoding_Error("Invalid ChangeCipherSpec");
195 }
196
197 // TODO: check this is otherwise empty
198 m_ccs_epochs.insert(epoch);
199 return;
200 }
201
202 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
203
204 while(record_len > 0) {
205 if(record_len < DTLS_HANDSHAKE_HEADER_LEN) {
206 return; // completely bogus? at least degenerate/weird
207 }
208
209 const Handshake_Type msg_type = static_cast<Handshake_Type>(record[0]);
210 const size_t msg_len = load_be24(&record[1]);
211 const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
212 const size_t fragment_offset = load_be24(&record[6]);
213 const size_t fragment_length = load_be24(&record[9]);
214
215 const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
216
217 if(record_len < total_size) {
218 throw Decoding_Error("Bad lengths in DTLS header");
219 }
220
221 if(message_seq >= m_in_message_seq) {
222 m_messages[message_seq].add_fragment(
223 &record[DTLS_HANDSHAKE_HEADER_LEN], fragment_length, fragment_offset, epoch, msg_type, msg_len);
224 } else {
225 // TODO: detect retransmitted flight
226 }
227
228 record += total_size;
229 record_len -= total_size;
230 }
231}
232
233std::pair<Handshake_Type, std::vector<uint8_t>> Datagram_Handshake_IO::get_next_record(bool expecting_ccs) {
234 // Expecting a message means the last flight is concluded
235 if(!m_flights.rbegin()->empty()) {
236 m_flights.push_back(std::vector<uint16_t>());
237 }
238
239 if(expecting_ccs) {
240 if(!m_messages.empty()) {
241 const uint16_t current_epoch = m_messages.begin()->second.epoch();
242
243 if(m_ccs_epochs.contains(current_epoch)) {
244 return std::make_pair(Handshake_Type::HandshakeCCS, std::vector<uint8_t>());
245 }
246 }
247 return std::make_pair(Handshake_Type::None, std::vector<uint8_t>());
248 }
249
250 auto i = m_messages.find(m_in_message_seq);
251
252 if(i == m_messages.end() || !i->second.complete()) {
253 return std::make_pair(Handshake_Type::None, std::vector<uint8_t>());
254 }
255
256 m_in_message_seq += 1;
257
258 return i->second.message();
259}
260
261void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(const uint8_t fragment[],
262 size_t fragment_length,
263 size_t fragment_offset,
264 uint16_t epoch,
265 Handshake_Type msg_type,
266 size_t msg_length) {
267 if(complete()) {
268 return; // already have entire message, ignore this
269 }
270
271 if(m_msg_type == Handshake_Type::None) {
272 m_epoch = epoch;
273 m_msg_type = msg_type;
274 m_msg_length = msg_length;
275 }
276
277 if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch) {
278 throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
279 }
280
281 if(fragment_offset > m_msg_length) {
282 throw Decoding_Error("Fragment offset past end of message");
283 }
284
285 if(fragment_offset + fragment_length > m_msg_length) {
286 throw Decoding_Error("Fragment overlaps past end of message");
287 }
288
289 if(fragment_offset == 0 && fragment_length == m_msg_length) {
290 m_fragments.clear();
291 m_message.assign(fragment, fragment + fragment_length);
292 } else {
293 /*
294 * FIXME. This is a pretty lame way to do defragmentation, huge
295 * overhead with a tree node per byte.
296 *
297 * Also should confirm that all overlaps have no changes,
298 * otherwise we expose ourselves to the classic fingerprinting
299 * and IDS evasion attacks on IP fragmentation.
300 */
301 for(size_t i = 0; i != fragment_length; ++i) {
302 m_fragments[fragment_offset + i] = fragment[i];
303 }
304
305 if(m_fragments.size() == m_msg_length) {
306 m_message.resize(m_msg_length);
307 for(size_t i = 0; i != m_msg_length; ++i) {
308 m_message[i] = m_fragments[i];
309 }
310 m_fragments.clear();
311 }
312 }
313}
314
315bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const {
316 return (m_msg_type != Handshake_Type::None && m_message.size() == m_msg_length);
317}
318
319std::pair<Handshake_Type, std::vector<uint8_t>> Datagram_Handshake_IO::Handshake_Reassembly::message() const {
320 if(!complete()) {
321 throw Internal_Error("Datagram_Handshake_IO - message not complete");
322 }
323
324 return std::make_pair(m_msg_type, m_message);
325}
326
327std::vector<uint8_t> Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
328 size_t frag_len,
329 uint16_t frag_offset,
330 uint16_t msg_len,
331 Handshake_Type type,
332 uint16_t msg_sequence) const {
333 std::vector<uint8_t> send_buf(12 + frag_len);
334
335 send_buf[0] = static_cast<uint8_t>(type);
336
337 store_be24(&send_buf[1], msg_len);
338
339 store_be(msg_sequence, &send_buf[4]);
340
341 store_be24(&send_buf[6], frag_offset);
342 store_be24(&send_buf[9], frag_len);
343
344 if(frag_len > 0) {
345 copy_mem(&send_buf[12], fragment, frag_len);
346 }
347
348 return send_buf;
349}
350
351std::vector<uint8_t> Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
352 Handshake_Type type,
353 uint16_t msg_sequence) const {
354 return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
355}
356
357std::vector<uint8_t> Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg, Handshake_Type type) const {
358 return format_w_seq(msg, type, m_in_message_seq - 1);
359}
360
361std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg) {
362 return this->send_under_epoch(msg, m_seqs.current_write_epoch());
363}
364
365std::vector<uint8_t> Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch) {
366 const std::vector<uint8_t> msg_bits = msg.serialize();
367 const Handshake_Type msg_type = msg.type();
368
369 if(msg_type == Handshake_Type::HandshakeCCS) {
370 m_send_hs(epoch, Record_Type::ChangeCipherSpec, msg_bits);
371 return std::vector<uint8_t>(); // not included in handshake hashes
372 } else if(msg_type == Handshake_Type::HelloVerifyRequest) {
373 // This message is not included in the handshake hashes
374 send_message(m_out_message_seq, epoch, msg_type, msg_bits);
375 m_out_message_seq += 1;
376 return std::vector<uint8_t>();
377 }
378
379 // Note: not saving CCS, instead we know it was there due to change in epoch
380 m_flights.rbegin()->push_back(m_out_message_seq);
381 m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
382
383 m_out_message_seq += 1;
384 m_last_write = steady_clock_ms();
385 m_next_timeout = m_initial_timeout;
386
387 return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
388}
389
390std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
391 uint16_t epoch,
392 Handshake_Type msg_type,
393 const std::vector<uint8_t>& msg_bits) {
394 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
395
396 auto no_fragment = format_w_seq(msg_bits, msg_type, msg_seq);
397
398 if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) {
399 m_send_hs(epoch, Record_Type::Handshake, no_fragment);
400 } else {
401 size_t frag_offset = 0;
402
403 /**
404 * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
405 * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
406 * over-estimate here. When CBC ciphers are removed this can be reduced
407 * since AEAD modes have no padding, at most 16 byte mac, and smaller
408 * per-record nonce.
409 */
410 const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
411 const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
412
413 if(m_mtu <= (header_overhead + ciphersuite_overhead)) {
414 throw Invalid_Argument("DTLS MTU is too small to send headers");
415 }
416
417 const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
418
419 while(frag_offset != msg_bits.size()) {
420 const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
421
422 const std::vector<uint8_t> frag = format_fragment(&msg_bits[frag_offset],
423 frag_len,
424 static_cast<uint16_t>(frag_offset),
425 static_cast<uint16_t>(msg_bits.size()),
426 msg_type,
427 msg_seq);
428
429 m_send_hs(epoch, Record_Type::Handshake, frag);
430
431 frag_offset += frag_len;
432 }
433 }
434
435 return no_fragment;
436}
437
438} // namespace Botan::TLS
#define BOTAN_ASSERT(expr, assertion_made)
Definition assert.h:62
std::vector< uint8_t > send_under_epoch(const Handshake_Message &msg, uint16_t epoch) override
void add_record(const uint8_t record[], size_t record_len, Record_Type type, uint64_t sequence_number) override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
Protocol_Version initial_record_version() const override
std::vector< uint8_t > send(const Handshake_Message &msg) override
virtual Handshake_Type type() const =0
virtual std::vector< uint8_t > serialize() const =0
virtual Handshake_Type wire_type() const
std::vector< uint8_t > send_under_epoch(const Handshake_Message &msg, uint16_t epoch) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
Protocol_Version initial_record_version() const override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::vector< uint8_t > send(const Handshake_Message &msg) override
void add_record(const uint8_t record[], size_t record_len, Record_Type type, uint64_t sequence_number) override
@ DTLS_HEADER_SIZE
Definition tls_magic.h:27
constexpr uint8_t get_byte(T input)
Definition loadstor.h:79
constexpr void copy_mem(T *out, const T *in, size_t n)
Definition mem_ops.h:144
constexpr uint32_t make_uint32(uint8_t i0, uint8_t i1, uint8_t i2, uint8_t i3)
Definition loadstor.h:104
constexpr auto store_be(ParamTs &&... params)
Definition loadstor.h:745
constexpr auto load_be(ParamTs &&... params)
Definition loadstor.h:504