12#include <botan/tls_extensions.h>
14#include <botan/ber_dec.h>
15#include <botan/der_enc.h>
16#include <botan/tls_exceptn.h>
17#include <botan/tls_policy.h>
18#include <botan/internal/stl_util.h>
19#include <botan/internal/tls_reader.h>
27std::unique_ptr<Extension> make_extension(TLS_Data_Reader& reader,
33 const uint16_t size =
static_cast<uint16_t
>(reader.remaining_bytes());
36 return std::make_unique<Server_Name_Indicator>(reader, size);
39 return std::make_unique<Supported_Groups>(reader, size);
42 return std::make_unique<Certificate_Status_Request>(reader, size, message_type, from);
45 return std::make_unique<Supported_Point_Formats>(reader, size);
48 return std::make_unique<Renegotiation_Extension>(reader, size);
51 return std::make_unique<Signature_Algorithms>(reader, size);
54 return std::make_unique<Signature_Algorithms_Cert>(reader, size);
57 return std::make_unique<SRTP_Protection_Profiles>(reader, size);
60 return std::make_unique<Application_Layer_Protocol_Notification>(reader, size, from);
63 return std::make_unique<Extended_Master_Secret>(reader, size);
66 return std::make_unique<Record_Size_Limit>(reader, size, from);
69 return std::make_unique<Encrypt_then_MAC>(reader, size);
72 return std::make_unique<Session_Ticket_Extension>(reader, size);
75 return std::make_unique<Supported_Versions>(reader, size, from);
77#if defined(BOTAN_HAS_TLS_13)
78 case Extension_Code::PresharedKey:
79 return std::make_unique<PSK>(reader, size, message_type);
81 case Extension_Code::EarlyData:
82 return std::make_unique<EarlyDataIndication>(reader, size, message_type);
84 case Extension_Code::Cookie:
85 return std::make_unique<Cookie>(reader, size);
87 case Extension_Code::PskKeyExchangeModes:
88 return std::make_unique<PSK_Key_Exchange_Modes>(reader, size);
90 case Extension_Code::CertificateAuthorities:
91 return std::make_unique<Certificate_Authorities>(reader, size);
93 case Extension_Code::KeyShare:
94 return std::make_unique<Key_Share>(reader, size, message_type);
98 return std::make_unique<Unknown_Extension>(
static_cast<Extension_Code>(code), reader, size);
104 if(
has(extn->type())) {
106 std::to_string(
static_cast<uint16_t
>(extn->type())));
109 m_extensions.emplace_back(extn.release());
126 if(this->
has(type)) {
127 throw TLS_Exception(TLS::Alert::DecodeError,
"Peer sent duplicated extensions");
132 const std::vector<uint8_t> extn_data = reader.
get_fixed<uint8_t>(extension_size);
134 this->
add(make_extension(extn_reader, type, from, message_type));
135 extn_reader.assert_done();
141 const bool allow_unknown_extensions)
const {
144 std::vector<Extension_Code> diff;
146 found.cbegin(), found.end(), allowed_extensions.cbegin(), allowed_extensions.cend(), std::back_inserter(diff));
148 if(allow_unknown_extensions) {
151 const auto itr = std::find_if(diff.cbegin(), diff.cend(), [
this](
const auto ext_type) {
152 const auto ext = get(ext_type);
153 return ext && ext->is_implemented();
157 return itr != diff.cend();
160 return !diff.empty();
165 std::find_if(m_extensions.begin(), m_extensions.end(), [type](
const auto& ext) { return ext->type() == type; });
167 std::unique_ptr<Extension> result;
168 if(i != m_extensions.end()) {
169 std::swap(result, *i);
170 m_extensions.erase(i);
177 std::vector<uint8_t> buf(2);
179 for(
const auto& extn : m_extensions) {
184 const uint16_t extn_code =
static_cast<uint16_t
>(extn->type());
186 const std::vector<uint8_t> extn_val = extn->serialize(whoami);
188 buf.push_back(get_byte<0>(extn_code));
189 buf.push_back(get_byte<1>(extn_code));
191 buf.push_back(get_byte<0>(
static_cast<uint16_t
>(extn_val.size())));
192 buf.push_back(get_byte<1>(
static_cast<uint16_t
>(extn_val.size())));
197 const uint16_t extn_size =
static_cast<uint16_t
>(buf.size() - 2);
199 buf[0] = get_byte<0>(extn_size);
200 buf[1] = get_byte<1>(extn_size);
203 if(buf.size() == 2) {
204 return std::vector<uint8_t>();
211 std::set<Extension_Code> offers;
213 m_extensions.cbegin(), m_extensions.cend(), std::inserter(offers, offers.begin()), [](
const auto& ext) {
220 m_type(type), m_value(reader.get_fixed<uint8_t>(extension_size)) {}
230 if(extension_size == 0) {
236 if(name_bytes + 2 != extension_size) {
241 uint8_t name_type = reader.
get_byte();
246 m_sni_host_name = reader.
get_string(2, 1, 65535);
247 name_bytes -=
static_cast<uint16_t
>(2 + m_sni_host_name.size());
265 std::vector<uint8_t> buf;
267 size_t name_len = m_sni_host_name.size();
269 buf.push_back(get_byte<0>(
static_cast<uint16_t
>(name_len + 3)));
270 buf.push_back(get_byte<1>(
static_cast<uint16_t
>(name_len + 3)));
273 buf.push_back(get_byte<0>(
static_cast<uint16_t
>(name_len)));
274 buf.push_back(get_byte<1>(
static_cast<uint16_t
>(name_len)));
282 m_reneg_data(reader.get_range<uint8_t>(1, 0, 255)) {
283 if(m_reneg_data.size() + 1 != extension_size) {
284 throw Decoding_Error(
"Bad encoding for secure renegotiation extn");
289 std::vector<uint8_t> buf;
295 uint16_t extension_size,
297 if(extension_size == 0) {
303 size_t bytes_remaining = extension_size - 2;
305 if(name_bytes != bytes_remaining) {
306 throw Decoding_Error(
"Bad encoding of ALPN extension, bad length field");
309 while(bytes_remaining) {
310 const std::string p = reader.
get_string(1, 0, 255);
312 if(bytes_remaining < p.size() + 1) {
313 throw Decoding_Error(
"Bad encoding of ALPN, length field too long");
320 bytes_remaining -= (p.size() + 1);
322 m_protocols.push_back(p);
332 "Server sent " + std::to_string(m_protocols.size()) +
" protocols in ALPN extension response");
338 return m_protocols.front();
342 std::vector<uint8_t> buf(2);
344 for(
auto&& p : m_protocols) {
345 if(p.length() >= 256) {
346 throw TLS_Exception(Alert::InternalError,
"ALPN name too long");
353 buf[0] = get_byte<0>(
static_cast<uint16_t
>(buf.size() - 2));
354 buf[1] = get_byte<1>(
static_cast<uint16_t
>(buf.size() - 2));
366 std::vector<Group_Params> ec;
367 for(
auto g : m_groups) {
368 if(g.is_pure_ecc_group()) {
376 std::vector<Group_Params> dh;
377 for(
auto g : m_groups) {
378 if(g.is_dh_named_group()) {
386 std::vector<uint8_t> buf(2);
388 for(
auto g : m_groups) {
389 const uint16_t
id = g.wire_code();
392 buf.push_back(get_byte<0>(
id));
393 buf.push_back(get_byte<1>(
id));
397 buf[0] = get_byte<0>(
static_cast<uint16_t
>(buf.size() - 2));
398 buf[1] = get_byte<1>(
static_cast<uint16_t
>(buf.size() - 2));
406 if(len + 2 != extension_size) {
407 throw Decoding_Error(
"Inconsistent length field in supported groups list");
414 const size_t elems = len / 2;
416 for(
size_t i = 0; i != elems; ++i) {
420 m_groups.push_back(group);
427 if(m_prefers_compressed) {
437 if(len + 1 != extension_size) {
438 throw Decoding_Error(
"Inconsistent length field in supported point formats list");
441 bool includes_uncompressed =
false;
442 for(
size_t i = 0; i != len; ++i) {
446 m_prefers_compressed =
false;
450 m_prefers_compressed =
true;
451 std::vector<uint8_t> remaining_formats = reader.
get_fixed<uint8_t>(len - i - 1);
452 includes_uncompressed =
453 std::any_of(std::begin(remaining_formats), std::end(remaining_formats), [](uint8_t remaining_format) {
468 if(!includes_uncompressed) {
470 "Supported Point Formats Extension must contain the uncompressed point format");
476std::vector<uint8_t> serialize_signature_algorithms(
const std::vector<Signature_Scheme>& schemes) {
477 BOTAN_ASSERT(schemes.size() < 256,
"Too many signature schemes");
479 std::vector<uint8_t> buf;
481 const uint16_t len =
static_cast<uint16_t
>(schemes.size() * 2);
483 buf.push_back(get_byte<0>(len));
484 buf.push_back(get_byte<1>(len));
487 buf.push_back(get_byte<0>(scheme.wire_code()));
488 buf.push_back(get_byte<1>(scheme.wire_code()));
494std::vector<Signature_Scheme> parse_signature_algorithms(TLS_Data_Reader& reader, uint16_t extension_size) {
495 uint16_t len = reader.get_uint16_t();
497 if(len + 2 != extension_size || len % 2 == 1 || len == 0) {
498 throw Decoding_Error(
"Bad encoding on signature algorithms extension");
501 std::vector<Signature_Scheme> schemes;
502 schemes.reserve(len / 2);
504 schemes.emplace_back(reader.get_uint16_t());
514 return serialize_signature_algorithms(m_schemes);
518 m_schemes(parse_signature_algorithms(reader, extension_size)) {}
521 return serialize_signature_algorithms(m_schemes);
525 m_schemes(parse_signature_algorithms(reader, extension_size)) {}
528 m_ticket(
Session_Ticket(reader.get_elem<uint8_t, std::vector<uint8_t>>(extension_size))) {}
531 m_pp(reader.get_range<uint16_t>(2, 0, 65535)) {
532 const std::vector<uint8_t> mki = reader.
get_range<uint8_t>(1, 0, 255);
534 if(m_pp.size() * 2 + mki.size() + 3 != extension_size) {
535 throw Decoding_Error(
"Bad encoding for SRTP protection extension");
539 throw Decoding_Error(
"Unhandled non-empty MKI for SRTP protection extension");
544 std::vector<uint8_t> buf;
546 const uint16_t pp_len =
static_cast<uint16_t
>(m_pp.size() * 2);
547 buf.push_back(get_byte<0>(pp_len));
548 buf.push_back(get_byte<1>(pp_len));
550 for(uint16_t pp : m_pp) {
551 buf.push_back(get_byte<0>(pp));
552 buf.push_back(get_byte<1>(pp));
561 if(extension_size != 0) {
567 return std::vector<uint8_t>();
571 if(extension_size != 0) {
577 return std::vector<uint8_t>();
581 std::vector<uint8_t> buf;
585 buf.push_back(m_versions[0].major_version());
586 buf.push_back(m_versions[0].minor_version());
589 const uint8_t len =
static_cast<uint8_t
>(m_versions.size() * 2);
594 buf.push_back(version.major_version());
595 buf.push_back(version.minor_version());
604#if defined(BOTAN_HAS_TLS_12)
605 if(offer >= Protocol_Version::DTLS_V12 && policy.
allow_dtls12()) {
606 m_versions.push_back(Protocol_Version::DTLS_V12);
610#if defined(BOTAN_HAS_TLS_13)
611 if(offer >= Protocol_Version::TLS_V13 && policy.
allow_tls13()) {
612 m_versions.push_back(Protocol_Version::TLS_V13);
615#if defined(BOTAN_HAS_TLS_12)
616 if(offer >= Protocol_Version::TLS_V12 && policy.
allow_tls12()) {
617 m_versions.push_back(Protocol_Version::TLS_V12);
625 if(extension_size != 2) {
626 throw Decoding_Error(
"Server sent invalid supported_versions extension");
636 if(extension_size != 1 + 2 *
versions.size()) {
637 throw Decoding_Error(
"Client sent invalid supported_versions extension");
643 for(
auto v : m_versions) {
652 BOTAN_ASSERT(
limit >= 64,
"RFC 8449 does not allow record size limits smaller than 64 bytes");
654 "RFC 8449 does not allow record size limits larger than 2^14+1");
658 if(extension_size != 2) {
659 throw TLS_Exception(Alert::DecodeError,
"invalid record_size_limit extension");
679 "Server requested a record size limit larger than the protocol's maximum");
687 throw TLS_Exception(Alert::IllegalParameter,
"Received a record size limit smaller than 64 bytes");
692 std::vector<uint8_t> buf;
694 buf.push_back(get_byte<0>(m_limit));
695 buf.push_back(get_byte<1>(m_limit));
700#if defined(BOTAN_HAS_TLS_13)
701Cookie::Cookie(
const std::vector<uint8_t>& cookie) : m_cookie(cookie) {}
703Cookie::Cookie(TLS_Data_Reader& reader, uint16_t extension_size) {
704 if(extension_size == 0) {
708 const uint16_t len = reader.get_uint16_t();
712 throw Decoding_Error(
"Cookie length must be at least 1 byte");
715 if(len > reader.remaining_bytes()) {
716 throw Decoding_Error(
"Not enough bytes in the buffer to decode Cookie");
719 for(
size_t i = 0; i < len; ++i) {
720 m_cookie.push_back(reader.get_byte());
725 std::vector<uint8_t> buf;
727 const uint16_t len =
static_cast<uint16_t
>(m_cookie.size());
729 buf.push_back(get_byte<0>(len));
730 buf.push_back(get_byte<1>(len));
732 for(
const auto& cookie_byte : m_cookie) {
733 buf.push_back(cookie_byte);
739std::vector<uint8_t> PSK_Key_Exchange_Modes::serialize(
Connection_Side)
const {
740 std::vector<uint8_t> buf;
743 buf.push_back(
static_cast<uint8_t
>(m_modes.size()));
744 for(
const auto& mode : m_modes) {
745 buf.push_back(
static_cast<uint8_t
>(mode));
751PSK_Key_Exchange_Modes::PSK_Key_Exchange_Modes(TLS_Data_Reader& reader, uint16_t extension_size) {
752 if(extension_size < 2) {
753 throw Decoding_Error(
"Empty psk_key_exchange_modes extension is illegal");
756 const auto mode_count = reader.get_byte();
757 for(uint16_t i = 0; i < mode_count; ++i) {
758 const auto mode =
static_cast<PSK_Key_Exchange_Mode
>(reader.get_byte());
759 if(mode == PSK_Key_Exchange_Mode::PSK_KE || mode == PSK_Key_Exchange_Mode::PSK_DHE_KE) {
760 m_modes.push_back(mode);
765std::vector<uint8_t> Certificate_Authorities::serialize(
Connection_Side)
const {
766 std::vector<uint8_t> out;
767 std::vector<uint8_t> dn_list;
769 for(
const auto& dn : m_distinguished_names) {
770 std::vector<uint8_t> encoded_dn;
771 auto encoder = DER_Encoder(encoded_dn);
772 dn.encode_into(encoder);
781Certificate_Authorities::Certificate_Authorities(TLS_Data_Reader& reader, uint16_t extension_size) {
782 if(extension_size < 2) {
783 throw Decoding_Error(
"Empty certificate_authorities extension is illegal");
786 const uint16_t purported_size = reader.get_uint16_t();
788 if(reader.remaining_bytes() != purported_size) {
789 throw Decoding_Error(
"Inconsistent length in certificate_authorities extension");
792 while(reader.has_remaining()) {
793 std::vector<uint8_t> name_bits = reader.get_tls_length_value(2);
795 BER_Decoder decoder(name_bits.data(), name_bits.size());
796 m_distinguished_names.emplace_back();
797 decoder.decode(m_distinguished_names.back());
801Certificate_Authorities::Certificate_Authorities(std::vector<X509_DN> acceptable_DNs) :
802 m_distinguished_names(std::move(acceptable_DNs)) {}
804std::vector<uint8_t> EarlyDataIndication::serialize(
Connection_Side)
const {
805 std::vector<uint8_t> result;
806 if(m_max_early_data_size.has_value()) {
807 const auto max_data = m_max_early_data_size.value();
808 result.push_back(get_byte<0>(max_data));
809 result.push_back(get_byte<1>(max_data));
810 result.push_back(get_byte<2>(max_data));
811 result.push_back(get_byte<3>(max_data));
816EarlyDataIndication::EarlyDataIndication(TLS_Data_Reader& reader,
817 uint16_t extension_size,
820 if(extension_size != 4) {
821 throw TLS_Exception(Alert::DecodeError,
822 "Received an early_data extension in a NewSessionTicket message "
823 "without maximum early data size indication");
826 m_max_early_data_size = reader.get_uint32_t();
827 }
else if(extension_size != 0) {
828 throw TLS_Exception(Alert::DecodeError,
829 "Received an early_data extension containing an unexpected data "
834bool EarlyDataIndication::empty()
const {
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_STATE_CHECK(expr)
#define BOTAN_ASSERT(expr, assertion_made)
Application_Layer_Protocol_Notification(std::string_view protocol)
std::string single_protocol() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Encrypt_then_MAC()=default
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Extended_Master_Secret()=default
std::vector< uint8_t > serialize(Connection_Side whoami) const
void deserialize(TLS_Data_Reader &reader, Connection_Side from, Handshake_Type message_type)
std::set< Extension_Code > extension_types() const
void add(std::unique_ptr< Extension > extn)
bool contains_other_than(const std::set< Extension_Code > &allowed_extensions, bool allow_unknown_extensions=false) const
virtual bool allow_tls12() const
virtual bool allow_tls13() const
virtual bool allow_dtls12() const
bool is_datagram_protocol() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Record_Size_Limit(uint16_t limit)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Renegotiation_Extension()=default
SRTP_Protection_Profiles(const std::vector< uint16_t > &pp)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Server_Name_Indicator(std::string_view host_name)
Session_Ticket_Extension()=default
Signature_Algorithms_Cert(std::vector< Signature_Scheme > schemes)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Signature_Algorithms(std::vector< Signature_Scheme > schemes)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Supported_Groups(const std::vector< Group_Params > &groups)
std::vector< Group_Params > ec_groups() const
const std::vector< Group_Params > & groups() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::vector< Group_Params > dh_groups() const
bool supports(Protocol_Version version) const
Supported_Versions(Protocol_Version version, const Policy &policy)
const std::vector< Protocol_Version > & versions() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::string get_string(size_t len_bytes, size_t min_bytes, size_t max_bytes)
bool has_remaining() const
void discard_next(size_t bytes)
std::vector< T > get_range(size_t len_bytes, size_t min_elems, size_t max_elems)
size_t remaining_bytes() const
std::vector< T > get_fixed(size_t size)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Unknown_Extension(Extension_Code type, TLS_Data_Reader &reader, uint16_t extension_size)
void append_tls_length_value(std::vector< uint8_t, Alloc > &buf, const T *vals, size_t vals_size, size_t tag_size)
@ CertSignatureAlgorithms
@ ApplicationLayerProtocolNegotiation
@ CertificateStatusRequest
bool value_exists(const std::vector< T > &vec, const OT &val)
const uint8_t * cast_char_ptr_to_uint8(const char *s)