Botan  2.6.0
Crypto and TLS for C++11
tls_handshake_io.cpp
Go to the documentation of this file.
1 /*
2 * TLS Handshake IO
3 * (C) 2012,2014,2015 Jack Lloyd
4 *
5 * Botan is released under the Simplified BSD License (see license.txt)
6 */
7 
8 #include <botan/internal/tls_handshake_io.h>
9 #include <botan/internal/tls_record.h>
10 #include <botan/internal/tls_seq_numbers.h>
11 #include <botan/tls_messages.h>
12 #include <botan/exceptn.h>
13 #include <chrono>
14 
15 namespace Botan {
16 
17 namespace TLS {
18 
19 namespace {
20 
21 inline size_t load_be24(const uint8_t q[3])
22  {
23  return make_uint32(0,
24  q[0],
25  q[1],
26  q[2]);
27  }
28 
29 void store_be24(uint8_t out[3], size_t val)
30  {
31  out[0] = get_byte(1, static_cast<uint32_t>(val));
32  out[1] = get_byte(2, static_cast<uint32_t>(val));
33  out[2] = get_byte(3, static_cast<uint32_t>(val));
34  }
35 
/*
* Current reading of std::chrono::steady_clock, expressed as a count of
* milliseconds since that clock's epoch.
*
* @return milliseconds elapsed on the steady clock
*/
uint64_t steady_clock_ms()
   {
   using std::chrono::milliseconds;
   using std::chrono::steady_clock;

   const auto since_epoch = steady_clock::now().time_since_epoch();

   return std::chrono::duration_cast<milliseconds>(since_epoch).count();
   }
41 
/*
* Choose how many pieces a handshake message should be split into.
*
* Starts from (msg_size + mtu) / mtu and adds one more piece when that
* count plus the 25 byte header allowance exceeds the MTU.
*
* @param mtu the datagram size limit; must be nonzero as it is used as a divisor
* @param msg_size length in bytes of the message body
* @return the number of pieces to use
*/
size_t split_for_mtu(size_t mtu, size_t msg_size)
   {
   const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers

   size_t pieces = (msg_size + mtu) / mtu;

   if(pieces + DTLS_HEADERS_SIZE > mtu)
      ++pieces;

   return pieces;
   }
53 
54 }
55 
57  {
59  }
60 
61 void Stream_Handshake_IO::add_record(const std::vector<uint8_t>& record,
62  Record_Type record_type, uint64_t)
63  {
64  if(record_type == HANDSHAKE)
65  {
66  m_queue.insert(m_queue.end(), record.begin(), record.end());
67  }
68  else if(record_type == CHANGE_CIPHER_SPEC)
69  {
70  if(record.size() != 1 || record[0] != 1)
71  throw Decoding_Error("Invalid ChangeCipherSpec");
72 
73  // Pretend it's a regular handshake message of zero length
74  const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
75  m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
76  }
77  else
78  throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
79  }
80 
81 std::pair<Handshake_Type, std::vector<uint8_t>>
83  {
84  if(m_queue.size() >= 4)
85  {
86  const size_t length = make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
87 
88  if(m_queue.size() >= length + 4)
89  {
90  Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
91 
92  std::vector<uint8_t> contents(m_queue.begin() + 4,
93  m_queue.begin() + 4 + length);
94 
95  m_queue.erase(m_queue.begin(), m_queue.begin() + 4 + length);
96 
97  return std::make_pair(type, contents);
98  }
99  }
100 
101  return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
102  }
103 
104 std::vector<uint8_t>
105 Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
106  Handshake_Type type) const
107  {
108  std::vector<uint8_t> send_buf(4 + msg.size());
109 
110  const size_t buf_size = msg.size();
111 
112  send_buf[0] = type;
113 
114  store_be24(&send_buf[1], buf_size);
115 
116  if (msg.size() > 0)
117  {
118  copy_mem(&send_buf[4], msg.data(), msg.size());
119  }
120 
121  return send_buf;
122  }
123 
124 std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
125  {
126  const std::vector<uint8_t> msg_bits = msg.serialize();
127 
128  if(msg.type() == HANDSHAKE_CCS)
129  {
130  m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
131  return std::vector<uint8_t>(); // not included in handshake hashes
132  }
133 
134  const std::vector<uint8_t> buf = format(msg_bits, msg.type());
135  m_send_hs(HANDSHAKE, buf);
136  return buf;
137  }
138 
140  {
142  }
143 
144 void Datagram_Handshake_IO::retransmit_last_flight()
145  {
146  const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
147  retransmit_flight(flight_idx);
148  }
149 
150 void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
151  {
152  const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
153 
154  BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
155 
156  uint16_t epoch = m_flight_data[flight[0]].epoch;
157 
158  for(auto msg_seq : flight)
159  {
160  auto& msg = m_flight_data[msg_seq];
161 
162  if(msg.epoch != epoch)
163  {
164  // Epoch gap: insert the CCS
165  std::vector<uint8_t> ccs(1, 1);
166  m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
167  }
168 
169  send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
170  epoch = msg.epoch;
171  }
172  }
173 
175  {
176  if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
177  {
178  /*
179  If we haven't written anything yet obviously no timeout.
180  Also no timeout possible if we are mid-flight,
181  */
182  return false;
183  }
184 
185  const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
186 
187  if(ms_since_write < m_next_timeout)
188  return false;
189 
190  retransmit_last_flight();
191 
192  m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
193  return true;
194  }
195 
196 void Datagram_Handshake_IO::add_record(const std::vector<uint8_t>& record,
197  Record_Type record_type,
198  uint64_t record_sequence)
199  {
200  const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
201 
202  if(record_type == CHANGE_CIPHER_SPEC)
203  {
204  // TODO: check this is otherwise empty
205  m_ccs_epochs.insert(epoch);
206  return;
207  }
208 
209  const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
210 
211  const uint8_t* record_bits = record.data();
212  size_t record_size = record.size();
213 
214  while(record_size)
215  {
216  if(record_size < DTLS_HANDSHAKE_HEADER_LEN)
217  return; // completely bogus? at least degenerate/weird
218 
219  const uint8_t msg_type = record_bits[0];
220  const size_t msg_len = load_be24(&record_bits[1]);
221  const uint16_t message_seq = load_be<uint16_t>(&record_bits[4], 0);
222  const size_t fragment_offset = load_be24(&record_bits[6]);
223  const size_t fragment_length = load_be24(&record_bits[9]);
224 
225  const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
226 
227  if(record_size < total_size)
228  throw Decoding_Error("Bad lengths in DTLS header");
229 
230  if(message_seq >= m_in_message_seq)
231  {
232  m_messages[message_seq].add_fragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
233  fragment_length,
234  fragment_offset,
235  epoch,
236  msg_type,
237  msg_len);
238  }
239  else
240  {
241  // TODO: detect retransmitted flight
242  }
243 
244  record_bits += total_size;
245  record_size -= total_size;
246  }
247  }
248 
249 std::pair<Handshake_Type, std::vector<uint8_t>>
251  {
252  // Expecting a message means the last flight is concluded
253  if(!m_flights.rbegin()->empty())
254  m_flights.push_back(std::vector<uint16_t>());
255 
256  if(expecting_ccs)
257  {
258  if(!m_messages.empty())
259  {
260  const uint16_t current_epoch = m_messages.begin()->second.epoch();
261 
262  if(m_ccs_epochs.count(current_epoch))
263  return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
264  }
265  return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
266  }
267 
268  auto i = m_messages.find(m_in_message_seq);
269 
270  if(i == m_messages.end() || !i->second.complete())
271  return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
272 
273  m_in_message_seq += 1;
274 
275  return i->second.message();
276  }
277 
278 void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
279  const uint8_t fragment[],
280  size_t fragment_length,
281  size_t fragment_offset,
282  uint16_t epoch,
283  uint8_t msg_type,
284  size_t msg_length)
285  {
286  if(complete())
287  return; // already have entire message, ignore this
288 
289  if(m_msg_type == HANDSHAKE_NONE)
290  {
291  m_epoch = epoch;
292  m_msg_type = msg_type;
293  m_msg_length = msg_length;
294  }
295 
296  if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
297  throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
298 
299  if(fragment_offset > m_msg_length)
300  throw Decoding_Error("Fragment offset past end of message");
301 
302  if(fragment_offset + fragment_length > m_msg_length)
303  throw Decoding_Error("Fragment overlaps past end of message");
304 
305  if(fragment_offset == 0 && fragment_length == m_msg_length)
306  {
307  m_fragments.clear();
308  m_message.assign(fragment, fragment+fragment_length);
309  }
310  else
311  {
312  /*
313  * FIXME. This is a pretty lame way to do defragmentation, huge
314  * overhead with a tree node per byte.
315  *
316  * Also should confirm that all overlaps have no changes,
317  * otherwise we expose ourselves to the classic fingerprinting
318  * and IDS evasion attacks on IP fragmentation.
319  */
320  for(size_t i = 0; i != fragment_length; ++i)
321  m_fragments[fragment_offset+i] = fragment[i];
322 
323  if(m_fragments.size() == m_msg_length)
324  {
325  m_message.resize(m_msg_length);
326  for(size_t i = 0; i != m_msg_length; ++i)
327  m_message[i] = m_fragments[i];
328  m_fragments.clear();
329  }
330  }
331  }
332 
333 bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
334  {
335  return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
336  }
337 
338 std::pair<Handshake_Type, std::vector<uint8_t>>
339 Datagram_Handshake_IO::Handshake_Reassembly::message() const
340  {
341  if(!complete())
342  throw Internal_Error("Datagram_Handshake_IO - message not complete");
343 
344  return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
345  }
346 
347 std::vector<uint8_t>
348 Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
349  size_t frag_len,
350  uint16_t frag_offset,
351  uint16_t msg_len,
353  uint16_t msg_sequence) const
354  {
355  std::vector<uint8_t> send_buf(12 + frag_len);
356 
357  send_buf[0] = type;
358 
359  store_be24(&send_buf[1], msg_len);
360 
361  store_be(msg_sequence, &send_buf[4]);
362 
363  store_be24(&send_buf[6], frag_offset);
364  store_be24(&send_buf[9], frag_len);
365 
366  if (frag_len > 0)
367  {
368  copy_mem(&send_buf[12], fragment, frag_len);
369  }
370 
371  return send_buf;
372  }
373 
374 std::vector<uint8_t>
375 Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
377  uint16_t msg_sequence) const
378  {
379  return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
380  }
381 
382 std::vector<uint8_t>
383 Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
384  Handshake_Type type) const
385  {
386  return format_w_seq(msg, type, m_in_message_seq - 1);
387  }
388 
389 
390 std::vector<uint8_t>
392  {
393  const std::vector<uint8_t> msg_bits = msg.serialize();
394  const uint16_t epoch = m_seqs.current_write_epoch();
395  const Handshake_Type msg_type = msg.type();
396 
397  if(msg_type == HANDSHAKE_CCS)
398  {
399  m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
400  return std::vector<uint8_t>(); // not included in handshake hashes
401  }
402 
403  // Note: not saving CCS, instead we know it was there due to change in epoch
404  m_flights.rbegin()->push_back(m_out_message_seq);
405  m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
406 
407  m_out_message_seq += 1;
408  m_last_write = steady_clock_ms();
409  m_next_timeout = m_initial_timeout;
410 
411  return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
412  }
413 
414 std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
415  uint16_t epoch,
416  Handshake_Type msg_type,
417  const std::vector<uint8_t>& msg_bits)
418  {
419  const std::vector<uint8_t> no_fragment =
420  format_w_seq(msg_bits, msg_type, msg_seq);
421 
422  if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
423  {
424  m_send_hs(epoch, HANDSHAKE, no_fragment);
425  }
426  else
427  {
428  const size_t parts = split_for_mtu(m_mtu, msg_bits.size());
429 
430  const size_t parts_size = (msg_bits.size() + parts) / parts;
431 
432  size_t frag_offset = 0;
433 
434  while(frag_offset != msg_bits.size())
435  {
436  const size_t frag_len =
437  std::min<size_t>(msg_bits.size() - frag_offset,
438  parts_size);
439 
440  m_send_hs(epoch,
441  HANDSHAKE,
442  format_fragment(&msg_bits[frag_offset],
443  frag_len,
444  static_cast<uint16_t>(frag_offset),
445  static_cast<uint16_t>(msg_bits.size()),
446  msg_type,
447  msg_seq));
448 
449  frag_offset += frag_len;
450  }
451  }
452 
453  return no_fragment;
454  }
455 
456 }
457 }
void store_be(uint16_t in, uint8_t out[2])
Definition: loadstor.h:434
Protocol_Version initial_record_version() const override
uint16_t load_be< uint16_t >(const uint8_t in[], size_t off)
Definition: loadstor.h:137
Protocol_Version initial_record_version() const override
std::vector< uint8_t > send(const Handshake_Message &msg) override
std::string to_string(const BER_Object &obj)
Definition: asn1_obj.cpp:145
MechanismType type
#define BOTAN_ASSERT(expr, assertion_made)
Definition: assert.h:30
std::vector< uint8_t > send(const Handshake_Message &msg) override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
void add_record(const std::vector< uint8_t > &record, Record_Type type, uint64_t sequence_number) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
void copy_mem(T *out, const T *in, size_t n)
Definition: mem_ops.h:108
Definition: alg_id.cpp:13
virtual Handshake_Type type() const =0
virtual uint16_t current_write_epoch() const =0
uint8_t get_byte(size_t byte_num, T input)
Definition: loadstor.h:39
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
void add_record(const std::vector< uint8_t > &record, Record_Type type, uint64_t sequence_number) override
virtual std::vector< uint8_t > serialize() const =0
uint32_t make_uint32(uint8_t i0, uint8_t i1, uint8_t i2, uint8_t i3)
Definition: loadstor.h:65