8#include <botan/internal/tls_handshake_io.h>
9#include <botan/internal/tls_record.h>
10#include <botan/internal/tls_seq_numbers.h>
11#include <botan/tls_messages.h>
12#include <botan/exceptn.h>
13#include <botan/loadstor.h>
22inline size_t load_be24(
const uint8_t q[3])
/*
* Write the low 24 bits of val into out[0..2] in big-endian order
* (most significant byte first). Used for the 3-byte length and
* fragment offset/length fields of the handshake headers built here.
*
* NOTE(review): any bits of val above bit 23 are silently discarded,
* so a value >= 2^24 is encoded as a wrapped length. Confirm callers
* bound the size before calling.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   const uint32_t word = static_cast<uint32_t>(val);

   for(size_t i = 0; i != 3; ++i)
      {
      // out[0] takes bits 16..23, out[1] bits 8..15, out[2] bits 0..7
      const size_t shift = 8 * (2 - i);
      out[i] = static_cast<uint8_t>(word >> shift);
      }
   }
/*
* Current reading of std::chrono::steady_clock, in whole milliseconds
* since that clock's epoch. The epoch is unspecified, so only the
* difference between two readings is meaningful.
*/
uint64_t steady_clock_ms()
   {
   namespace chr = std::chrono;

   const chr::steady_clock::duration since_epoch =
      chr::steady_clock::now().time_since_epoch();
   const chr::milliseconds as_ms = chr::duration_cast<chr::milliseconds>(since_epoch);

   return as_ms.count();
   }
56 m_queue.insert(m_queue.end(), record, record + record_len);
60 if(record_len != 1 || record[0] != 1)
65 m_queue.insert(m_queue.end(), ccs_hs, ccs_hs +
sizeof(ccs_hs));
71std::pair<Handshake_Type, std::vector<uint8_t>>
74 if(m_queue.size() >= 4)
76 const size_t length = 4 +
make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
78 if(m_queue.size() >= length)
85 std::vector<uint8_t> contents(m_queue.begin() + 4,
86 m_queue.begin() + length);
88 m_queue.erase(m_queue.begin(), m_queue.begin() + length);
90 return std::make_pair(
type, contents);
101 std::vector<uint8_t> send_buf(4 + msg.size());
103 const size_t buf_size = msg.size();
105 send_buf[0] =
static_cast<uint8_t
>(
type);
107 store_be24(&send_buf[1], buf_size);
111 copy_mem(&send_buf[4], msg.data(), msg.size());
119 throw Invalid_State(
"Not possible to send under arbitrary epoch with stream based TLS");
124 const std::vector<uint8_t> msg_bits = msg.
serialize();
129 return std::vector<uint8_t>();
132 const std::vector<uint8_t> buf =
format(msg_bits, msg.
type());
142void Datagram_Handshake_IO::retransmit_last_flight()
144 const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
145 retransmit_flight(flight_idx);
148void Datagram_Handshake_IO::retransmit_flight(
size_t flight_idx)
150 const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
152 BOTAN_ASSERT(flight.size() > 0,
"Nonempty flight to retransmit");
154 uint16_t epoch = m_flight_data[flight[0]].epoch;
156 for(
auto msg_seq : flight)
158 auto& msg = m_flight_data[msg_seq];
160 if(msg.epoch != epoch)
163 std::vector<uint8_t> ccs(1, 1);
167 send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
174 if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
183 const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
185 if(ms_since_write < m_next_timeout)
188 retransmit_last_flight();
190 m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
197 uint64_t record_sequence)
199 const uint16_t epoch =
static_cast<uint16_t
>(record_sequence >> 48);
203 if(record_len != 1 || record[0] != 1)
207 m_ccs_epochs.insert(epoch);
211 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
215 if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
218 const uint8_t msg_type = record[0];
219 const size_t msg_len = load_be24(&record[1]);
221 const size_t fragment_offset = load_be24(&record[6]);
222 const size_t fragment_length = load_be24(&record[9]);
224 const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
226 if(record_len < total_size)
229 if(message_seq >= m_in_message_seq)
231 m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
243 record += total_size;
244 record_len -= total_size;
248std::pair<Handshake_Type, std::vector<uint8_t>>
252 if(!m_flights.rbegin()->empty())
253 m_flights.push_back(std::vector<uint16_t>());
257 if(!m_messages.empty())
259 const uint16_t current_epoch = m_messages.begin()->second.epoch();
261 if(m_ccs_epochs.count(current_epoch))
262 return std::make_pair(
HANDSHAKE_CCS, std::vector<uint8_t>());
267 auto i = m_messages.find(m_in_message_seq);
269 if(i == m_messages.end() || !i->second.complete())
274 m_in_message_seq += 1;
276 return i->second.message();
279void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
280 const uint8_t fragment[],
281 size_t fragment_length,
282 size_t fragment_offset,
293 m_msg_type = msg_type;
294 m_msg_length = msg_length;
297 if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
298 throw Decoding_Error(
"Inconsistent values in fragmented DTLS handshake header");
300 if(fragment_offset > m_msg_length)
303 if(fragment_offset + fragment_length > m_msg_length)
306 if(fragment_offset == 0 && fragment_length == m_msg_length)
309 m_message.assign(fragment, fragment+fragment_length);
321 for(
size_t i = 0; i != fragment_length; ++i)
322 m_fragments[fragment_offset+i] = fragment[i];
324 if(m_fragments.size() == m_msg_length)
326 m_message.resize(m_msg_length);
327 for(
size_t i = 0; i != m_msg_length; ++i)
328 m_message[i] = m_fragments[i];
334bool Datagram_Handshake_IO::Handshake_Reassembly::complete()
const
336 return (m_msg_type !=
HANDSHAKE_NONE && m_message.size() == m_msg_length);
339std::pair<Handshake_Type, std::vector<uint8_t>>
340Datagram_Handshake_IO::Handshake_Reassembly::message()
const
343 throw Internal_Error(
"Datagram_Handshake_IO - message not complete");
345 return std::make_pair(
static_cast<Handshake_Type>(m_msg_type), m_message);
349Datagram_Handshake_IO::format_fragment(
const uint8_t fragment[],
351 uint16_t frag_offset,
354 uint16_t msg_sequence)
const
356 std::vector<uint8_t> send_buf(12 + frag_len);
358 send_buf[0] =
static_cast<uint8_t
>(
type);
360 store_be24(&send_buf[1], msg_len);
362 store_be(msg_sequence, &send_buf[4]);
364 store_be24(&send_buf[6], frag_offset);
365 store_be24(&send_buf[9], frag_len);
369 copy_mem(&send_buf[12], fragment, frag_len);
376Datagram_Handshake_IO::format_w_seq(
const std::vector<uint8_t>& msg,
378 uint16_t msg_sequence)
const
380 return format_fragment(msg.data(), msg.size(), 0,
static_cast<uint16_t
>(msg.size()),
type, msg_sequence);
387 return format_w_seq(msg,
type, m_in_message_seq - 1);
398 const std::vector<uint8_t> msg_bits = msg.
serialize();
404 return std::vector<uint8_t>();
409 send_message(m_out_message_seq, epoch, msg_type, msg_bits);
410 m_out_message_seq += 1;
411 return std::vector<uint8_t>();
415 m_flights.rbegin()->push_back(m_out_message_seq);
416 m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
418 m_out_message_seq += 1;
419 m_last_write = steady_clock_ms();
420 m_next_timeout = m_initial_timeout;
422 return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
425std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
428 const std::vector<uint8_t>& msg_bits)
430 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
432 const std::vector<uint8_t> no_fragment =
433 format_w_seq(msg_bits, msg_type, msg_seq);
437 m_send_hs(epoch,
HANDSHAKE, no_fragment);
441 size_t frag_offset = 0;
450 const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
451 const size_t header_overhead =
DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
453 if(m_mtu <= (header_overhead + ciphersuite_overhead))
456 const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
458 while(frag_offset != msg_bits.size())
460 const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
462 const std::vector<uint8_t> frag =
463 format_fragment(&msg_bits[frag_offset],
465 static_cast<uint16_t
>(frag_offset),
466 static_cast<uint16_t
>(msg_bits.size()),
472 frag_offset += frag_len;
#define BOTAN_ASSERT(expr, assertion_made)
virtual uint16_t current_write_epoch() const =0
std::vector< uint8_t > send_under_epoch(const Handshake_Message &msg, uint16_t epoch) override
bool timeout_check() override
void add_record(const uint8_t record[], size_t record_len, Record_Type type, uint64_t sequence_number) override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
Protocol_Version initial_record_version() const override
std::vector< uint8_t > send(const Handshake_Message &msg) override
virtual Handshake_Type type() const =0
virtual std::vector< uint8_t > serialize() const =0
std::vector< uint8_t > send_under_epoch(const Handshake_Message &msg, uint16_t epoch) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
Protocol_Version initial_record_version() const override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::vector< uint8_t > send(const Handshake_Message &msg) override
void add_record(const uint8_t record[], size_t record_len, Record_Type type, uint64_t sequence_number) override
std::string to_string(const BER_Object &obj)
void store_be(uint16_t in, uint8_t out[2])
constexpr uint32_t make_uint32(uint8_t i0, uint8_t i1, uint8_t i2, uint8_t i3)
void copy_mem(T *out, const T *in, size_t n)
uint16_t load_be< uint16_t >(const uint8_t in[], size_t off)
constexpr uint8_t get_byte(size_t byte_num, T input)