11#include <botan/tls_extensions.h>
12#include <botan/internal/tls_reader.h>
13#include <botan/internal/stl_util.h>
14#include <botan/tls_exceptn.h>
15#include <botan/tls_policy.h>
16#include <botan/ber_dec.h>
24std::unique_ptr<Extension> make_extension(TLS_Data_Reader& reader,
31 const uint16_t size =
static_cast<uint16_t
>(reader.remaining_bytes());
35 return std::make_unique<Server_Name_Indicator>(reader, size);
38 return std::make_unique<Supported_Groups>(reader, size);
41 return std::make_unique<Certificate_Status_Request>(reader, size, message_type, from);
44 return std::make_unique<Supported_Point_Formats>(reader, size);
47 return std::make_unique<Renegotiation_Extension>(reader, size);
50 return std::make_unique<Signature_Algorithms>(reader, size);
53 return std::make_unique<Signature_Algorithms_Cert>(reader, size);
56 return std::make_unique<SRTP_Protection_Profiles>(reader, size);
59 return std::make_unique<Application_Layer_Protocol_Notification>(reader, size, from);
62 return std::make_unique<Extended_Master_Secret>(reader, size);
65 return std::make_unique<Record_Size_Limit>(reader, size, from);
68 return std::make_unique<Encrypt_then_MAC>(reader, size);
71 return std::make_unique<Session_Ticket_Extension>(reader, size);
74 return std::make_unique<Supported_Versions>(reader, size, from);
76#if defined(BOTAN_HAS_TLS_13)
77 case Extension_Code::PresharedKey:
78 return std::make_unique<PSK>(reader, size, message_type);
80 case Extension_Code::EarlyData:
81 return std::make_unique<EarlyDataIndication>(reader, size, message_type);
83 case Extension_Code::Cookie:
84 return std::make_unique<Cookie>(reader, size);
86 case Extension_Code::PskKeyExchangeModes:
87 return std::make_unique<PSK_Key_Exchange_Modes>(reader, size);
89 case Extension_Code::CertificateAuthorities:
90 return std::make_unique<Certificate_Authorities>(reader, size);
92 case Extension_Code::KeyShare:
93 return std::make_unique<Key_Share>(reader, size, message_type);
97 return std::make_unique<Unknown_Extension>(
static_cast<Extension_Code>(code),
105 if (
has(extn->type()))
108 std::to_string(
static_cast<uint16_t
>(extn->type())));
111 m_extensions.emplace_back(extn.release());
135 "Peer sent duplicated extensions");
140 const std::vector<uint8_t> extn_data = reader.
get_fixed<uint8_t>(extension_size);
142 this->
add(make_extension(extn_reader, type, from, message_type));
143 extn_reader.assert_done();
149 const bool allow_unknown_extensions)
const
153 std::vector<Extension_Code> diff;
154 std::set_difference(found.cbegin(), found.end(),
155 allowed_extensions.cbegin(), allowed_extensions.cend(),
156 std::back_inserter(diff));
158 if(allow_unknown_extensions)
162 const auto itr = std::find_if(diff.cbegin(), diff.cend(),
163 [
this](
const auto ext_type)
165 const auto ext = get(ext_type);
166 return ext && ext->is_implemented();
170 return itr != diff.cend();
173 return !diff.empty();
178 const auto i = std::find_if(m_extensions.begin(), m_extensions.end(),
179 [type](
const auto &ext) {
180 return ext->type() == type;
183 std::unique_ptr<Extension> result;
184 if (i != m_extensions.end())
186 std::swap(result, *i);
187 m_extensions.erase(i);
195 std::vector<uint8_t> buf(2);
197 for(
const auto& extn : m_extensions)
202 const uint16_t extn_code =
static_cast<uint16_t
>(extn->type());
204 const std::vector<uint8_t> extn_val = extn->serialize(whoami);
206 buf.push_back(get_byte<0>(extn_code));
207 buf.push_back(get_byte<1>(extn_code));
209 buf.push_back(get_byte<0>(
static_cast<uint16_t
>(extn_val.size())));
210 buf.push_back(get_byte<1>(
static_cast<uint16_t
>(extn_val.size())));
215 const uint16_t extn_size =
static_cast<uint16_t
>(buf.size() - 2);
217 buf[0] = get_byte<0>(extn_size);
218 buf[1] = get_byte<1>(extn_size);
222 return std::vector<uint8_t>();
229 std::set<Extension_Code> offers;
230 std::transform(m_extensions.cbegin(), m_extensions.cend(),
231 std::inserter(offers, offers.begin()), [] (
const auto &ext) {
239 uint16_t extension_size) :
241 m_value(reader.get_fixed<uint8_t>(extension_size))
251 uint16_t extension_size)
256 if(extension_size == 0)
261 if(name_bytes + 2 != extension_size)
266 uint8_t name_type = reader.
get_byte();
271 m_sni_host_name = reader.
get_string(2, 1, 65535);
272 name_bytes -=
static_cast<uint16_t
>(2 + m_sni_host_name.size());
293 std::vector<uint8_t> buf;
295 size_t name_len = m_sni_host_name.size();
297 buf.push_back(get_byte<0>(
static_cast<uint16_t
>(name_len+3)));
298 buf.push_back(get_byte<1>(
static_cast<uint16_t
>(name_len+3)));
301 buf.push_back(get_byte<0>(
static_cast<uint16_t
>(name_len)));
302 buf.push_back(get_byte<1>(
static_cast<uint16_t
>(name_len)));
304 buf += std::make_pair(
306 m_sni_host_name.size());
312 uint16_t extension_size) : m_reneg_data(reader.get_range<uint8_t>(1, 0, 255))
314 if(m_reneg_data.size() + 1 != extension_size)
315 throw Decoding_Error(
"Bad encoding for secure renegotiation extn");
320 std::vector<uint8_t> buf;
326 uint16_t extension_size,
329 if(extension_size == 0)
334 size_t bytes_remaining = extension_size - 2;
336 if(name_bytes != bytes_remaining)
337 throw Decoding_Error(
"Bad encoding of ALPN extension, bad length field");
339 while(bytes_remaining)
341 const std::string p = reader.
get_string(1, 0, 255);
343 if(bytes_remaining < p.size() + 1)
344 throw Decoding_Error(
"Bad encoding of ALPN, length field too long");
349 bytes_remaining -= (p.size() + 1);
351 m_protocols.push_back(p);
361 "Server sent " + std::to_string(m_protocols.size()) +
362 " protocols in ALPN extension response");
369 return m_protocols.front();
374 std::vector<uint8_t> buf(2);
376 for(
auto&& p: m_protocols)
378 if(p.length() >= 256)
379 throw TLS_Exception(Alert::InternalError,
"ALPN name too long");
387 buf[0] = get_byte<0>(
static_cast<uint16_t
>(buf.size()-2));
388 buf[1] = get_byte<1>(
static_cast<uint16_t
>(buf.size()-2));
404 std::vector<Group_Params> ec;
405 for(
auto g : m_groups)
415 std::vector<Group_Params> dh;
416 for(
auto g : m_groups)
426 std::vector<uint8_t> buf(2);
428 for(
auto g : m_groups)
430 const uint16_t
id =
static_cast<uint16_t
>(g);
434 buf.push_back(get_byte<0>(
id));
435 buf.push_back(get_byte<1>(
id));
439 buf[0] = get_byte<0>(
static_cast<uint16_t
>(buf.size()-2));
440 buf[1] = get_byte<1>(
static_cast<uint16_t
>(buf.size()-2));
446 uint16_t extension_size)
450 if(len + 2 != extension_size)
451 throw Decoding_Error(
"Inconsistent length field in supported groups list");
456 const size_t elems = len / 2;
458 for(
size_t i = 0; i != elems; ++i)
463 m_groups.push_back(group);
470 if(m_prefers_compressed)
481 uint16_t extension_size)
485 if(len + 1 != extension_size)
486 throw Decoding_Error(
"Inconsistent length field in supported point formats list");
488 for(
size_t i = 0; i != len; ++i)
494 m_prefers_compressed =
false;
500 m_prefers_compressed =
true;
511std::vector<uint8_t> serialize_signature_algorithms(
const std::vector<Signature_Scheme>& schemes)
513 BOTAN_ASSERT(schemes.size() < 256,
"Too many signature schemes");
515 std::vector<uint8_t> buf;
517 const uint16_t len =
static_cast<uint16_t
>(schemes.size() * 2);
519 buf.push_back(get_byte<0>(len));
520 buf.push_back(get_byte<1>(len));
524 buf.push_back(get_byte<0>(scheme.wire_code()));
525 buf.push_back(get_byte<1>(scheme.wire_code()));
531std::vector<Signature_Scheme> parse_signature_algorithms(TLS_Data_Reader& reader,
532 uint16_t extension_size)
534 uint16_t len = reader.get_uint16_t();
536 if(len + 2 != extension_size || len % 2 == 1 || len == 0)
538 throw Decoding_Error(
"Bad encoding on signature algorithms extension");
541 std::vector<Signature_Scheme> schemes;
542 schemes.reserve(len / 2);
545 schemes.emplace_back(reader.get_uint16_t());
556 return serialize_signature_algorithms(m_schemes);
560 uint16_t extension_size)
561 : m_schemes(parse_signature_algorithms(reader, extension_size))
566 return serialize_signature_algorithms(m_schemes);
570 uint16_t extension_size)
571 : m_schemes(parse_signature_algorithms(reader, extension_size))
575 uint16_t extension_size)
576 : m_ticket(
Session_Ticket(reader.get_elem<uint8_t,
std::vector<uint8_t>>(extension_size)))
580 uint16_t extension_size) : m_pp(reader.get_range<uint16_t>(2, 0, 65535))
582 const std::vector<uint8_t> mki = reader.
get_range<uint8_t>(1, 0, 255);
584 if(m_pp.size() * 2 + mki.size() + 3 != extension_size)
585 throw Decoding_Error(
"Bad encoding for SRTP protection extension");
588 throw Decoding_Error(
"Unhandled non-empty MKI for SRTP protection extension");
593 std::vector<uint8_t> buf;
595 const uint16_t pp_len =
static_cast<uint16_t
>(m_pp.size() * 2);
596 buf.push_back(get_byte<0>(pp_len));
597 buf.push_back(get_byte<1>(pp_len));
599 for(uint16_t pp : m_pp)
601 buf.push_back(get_byte<0>(pp));
602 buf.push_back(get_byte<1>(pp));
611 uint16_t extension_size)
613 if(extension_size != 0)
619 return std::vector<uint8_t>();
623 uint16_t extension_size)
625 if(extension_size != 0)
631 return std::vector<uint8_t>();
636 std::vector<uint8_t> buf;
641 buf.push_back(m_versions[0].major_version());
642 buf.push_back(m_versions[0].minor_version());
647 const uint8_t len =
static_cast<uint8_t
>(m_versions.size() * 2);
653 buf.push_back(version.major_version());
654 buf.push_back(version.minor_version());
666#if defined(BOTAN_HAS_TLS_12)
667 if(offer >= Protocol_Version::DTLS_V12 && policy.
allow_dtls12())
668 m_versions.push_back(Protocol_Version::DTLS_V12);
673#if defined(BOTAN_HAS_TLS_13)
674 if(offer >= Protocol_Version::TLS_V13 && policy.
allow_tls13())
675 m_versions.push_back(Protocol_Version::TLS_V13);
677#if defined(BOTAN_HAS_TLS_12)
678 if(offer >= Protocol_Version::TLS_V12 && policy.
allow_tls12())
679 m_versions.push_back(Protocol_Version::TLS_V12);
685 uint16_t extension_size,
690 if(extension_size != 2)
691 throw Decoding_Error(
"Server sent invalid supported_versions extension");
701 if(extension_size != 1+2*
versions.size())
702 throw Decoding_Error(
"Client sent invalid supported_versions extension");
708 for(
auto v : m_versions)
719 "RFC 8449 does not allow record size limits smaller than 64 bytes");
721 "RFC 8449 does not allow record size limits larger than 2^14+1");
725 uint16_t extension_size,
728 if(extension_size != 2)
730 throw TLS_Exception(Alert::DecodeError,
"invalid record_size_limit extension");
751 "Server requested a record size limit larger than the protocol's maximum");
761 "Received a record size limit smaller than 64 bytes");
767 std::vector<uint8_t> buf;
769 buf.push_back(get_byte<0>(m_limit));
770 buf.push_back(get_byte<1>(m_limit));
776#if defined(BOTAN_HAS_TLS_13)
777Cookie::Cookie(
const std::vector<uint8_t>& cookie) :
782Cookie::Cookie(TLS_Data_Reader& reader,
783 uint16_t extension_size)
785 if (extension_size == 0)
790 const uint16_t len = reader.get_uint16_t();
795 throw Decoding_Error(
"Cookie length must be at least 1 byte");
798 if (len > reader.remaining_bytes())
800 throw Decoding_Error(
"Not enough bytes in the buffer to decode Cookie");
803 for(
size_t i = 0; i < len; ++i)
805 m_cookie.push_back(reader.get_byte());
811 std::vector<uint8_t> buf;
813 const uint16_t len =
static_cast<uint16_t
>(m_cookie.size());
815 buf.push_back(get_byte<0>(len));
816 buf.push_back(get_byte<1>(len));
818 for (
const auto& cookie_byte : m_cookie)
820 buf.push_back(cookie_byte);
827std::vector<uint8_t> PSK_Key_Exchange_Modes::serialize(
Connection_Side)
const
829 std::vector<uint8_t> buf;
832 buf.push_back(
static_cast<uint8_t
>(m_modes.size()));
833 for (
const auto& mode : m_modes)
835 buf.push_back(
static_cast<uint8_t
>(mode));
841PSK_Key_Exchange_Modes::PSK_Key_Exchange_Modes(TLS_Data_Reader& reader, uint16_t extension_size)
843 if (extension_size < 2)
845 throw Decoding_Error(
"Empty psk_key_exchange_modes extension is illegal");
848 const auto mode_count = reader.get_byte();
849 for(uint16_t i = 0; i < mode_count; ++i)
851 const auto mode =
static_cast<PSK_Key_Exchange_Mode
>(reader.get_byte());
852 if (mode == PSK_Key_Exchange_Mode::PSK_KE ||
853 mode == PSK_Key_Exchange_Mode::PSK_DHE_KE)
855 m_modes.push_back(mode);
861std::vector<uint8_t> Certificate_Authorities::serialize(
Connection_Side)
const
862 {
throw Not_Implemented(
"serializing Certificate_Authorities is NYI"); }
864Certificate_Authorities::Certificate_Authorities(TLS_Data_Reader& reader, uint16_t extension_size)
866 if (extension_size < 2)
868 throw Decoding_Error(
"Empty certificate_authorities extension is illegal");
871 const uint16_t purported_size = reader.get_uint16_t();
873 if(reader.remaining_bytes() != purported_size)
874 throw Decoding_Error(
"Inconsistent length in certificate_authorities extension");
876 while(reader.has_remaining())
878 std::vector<uint8_t> name_bits = reader.get_tls_length_value(2);
880 BER_Decoder decoder(name_bits.data(), name_bits.size());
881 m_distinguished_names.emplace_back();
882 decoder.decode(m_distinguished_names.back());
886Certificate_Authorities::Certificate_Authorities(std::vector<X509_DN> acceptable_DNs)
887 : m_distinguished_names(
std::move(acceptable_DNs)) {}
890std::vector<uint8_t> EarlyDataIndication::serialize(
Connection_Side)
const
892 std::vector<uint8_t> result;
893 if(m_max_early_data_size.has_value())
895 const auto max_data = m_max_early_data_size.value();
896 result.push_back(get_byte<0>(max_data));
897 result.push_back(get_byte<1>(max_data));
898 result.push_back(get_byte<2>(max_data));
899 result.push_back(get_byte<3>(max_data));
904EarlyDataIndication::EarlyDataIndication(TLS_Data_Reader& reader,
905 uint16_t extension_size,
910 if(extension_size != 4)
912 throw TLS_Exception(Alert::DecodeError,
913 "Received an early_data extension in a NewSessionTicket message "
914 "without maximum early data size indication");
917 m_max_early_data_size = reader.get_uint32_t();
919 else if(extension_size != 0)
921 throw TLS_Exception(Alert::DecodeError,
922 "Received an early_data extension containing an unexpected data "
927bool EarlyDataIndication::empty()
const
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_STATE_CHECK(expr)
#define BOTAN_ASSERT(expr, assertion_made)
Application_Layer_Protocol_Notification(std::string_view protocol)
std::string single_protocol() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Encrypt_then_MAC()=default
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Extended_Master_Secret()=default
std::vector< uint8_t > serialize(Connection_Side whoami) const
bool contains_other_than(const std::set< Extension_Code > &allowed_extensions, const bool allow_unknown_extensions=false) const
void deserialize(TLS_Data_Reader &reader, const Connection_Side from, const Handshake_Type message_type)
std::set< Extension_Code > extension_types() const
void add(std::unique_ptr< Extension > extn)
virtual bool allow_tls12() const
virtual bool allow_tls13() const
virtual bool allow_dtls12() const
bool is_datagram_protocol() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Record_Size_Limit(const uint16_t limit)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Renegotiation_Extension()=default
SRTP_Protection_Profiles(const std::vector< uint16_t > &pp)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Server_Name_Indicator(std::string_view host_name)
Session_Ticket_Extension()=default
Signature_Algorithms_Cert(std::vector< Signature_Scheme > schemes)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Signature_Algorithms(std::vector< Signature_Scheme > schemes)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Supported_Groups(const std::vector< Group_Params > &groups)
std::vector< Group_Params > ec_groups() const
const std::vector< Group_Params > & groups() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::vector< Group_Params > dh_groups() const
bool supports(Protocol_Version version) const
Supported_Versions(Protocol_Version version, const Policy &policy)
const std::vector< Protocol_Version > & versions() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::string get_string(size_t len_bytes, size_t min_bytes, size_t max_bytes)
bool has_remaining() const
void discard_next(size_t bytes)
std::vector< T > get_range(size_t len_bytes, size_t min_elems, size_t max_elems)
size_t remaining_bytes() const
std::vector< T > get_fixed(size_t size)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Unknown_Extension(Extension_Code type, TLS_Data_Reader &reader, uint16_t extension_size)
void append_tls_length_value(std::vector< uint8_t, Alloc > &buf, const T *vals, size_t vals_size, size_t tag_size)
@ CertSignatureAlgorithms
@ ApplicationLayerProtocolNegotiation
@ CertificateStatusRequest
bool group_param_is_dh(Group_Params group)
bool value_exists(const std::vector< T > &vec, const OT &val)
const uint8_t * cast_char_ptr_to_uint8(const char *s)