13#include <botan/tls_extensions.h>
15#include <botan/tls_exceptn.h>
16#include <botan/tls_policy.h>
17#include <botan/internal/fmt.h>
18#include <botan/internal/parsing.h>
19#include <botan/internal/stl_util.h>
20#include <botan/internal/tls_reader.h>
22#include <unordered_set>
24#if defined(BOTAN_HAS_TLS_13)
25 #include <botan/tls_extensions_13.h>
28#if defined(BOTAN_HAS_TLS_12)
29 #include <botan/tls_extensions_12.h>
42 const uint16_t size =
static_cast<uint16_t
>(reader.remaining_bytes());
45 return std::make_unique<Server_Name_Indicator>(reader, size, from);
48 return std::make_unique<Supported_Groups>(reader, size);
51 return std::make_unique<Certificate_Status_Request>(reader, size, message_type, from);
54 return std::make_unique<Signature_Algorithms>(reader, size);
57 return std::make_unique<Signature_Algorithms_Cert>(reader, size);
60 return std::make_unique<SRTP_Protection_Profiles>(reader, size);
63 return std::make_unique<Application_Layer_Protocol_Notification>(reader, size, from);
66 return std::make_unique<Client_Certificate_Type>(reader, size, from);
69 return std::make_unique<Server_Certificate_Type>(reader, size, from);
72 return std::make_unique<Record_Size_Limit>(reader, size, from);
75 return std::make_unique<Supported_Versions>(reader, size, from);
80#if defined(BOTAN_HAS_TLS_12)
82 return std::make_unique<Supported_Point_Formats>(reader, size);
85 return std::make_unique<Renegotiation_Extension>(reader, size);
88 return std::make_unique<Extended_Master_Secret>(reader, size);
91 return std::make_unique<Encrypt_then_MAC>(reader, size);
94 return std::make_unique<Session_Ticket_Extension>(reader, size, from);
104#if defined(BOTAN_HAS_TLS_13)
106 return std::make_unique<PSK>(reader, size, message_type);
109 return std::make_unique<EarlyDataIndication>(reader, size, message_type);
112 return std::make_unique<Cookie>(reader, size);
115 return std::make_unique<PSK_Key_Exchange_Modes>(reader, size);
118 return std::make_unique<Certificate_Authorities>(reader, size);
121 return std::make_unique<Key_Share>(reader, size, message_type);
133 return std::make_unique<Unknown_Extension>(code, reader, size);
141 return m_extensions.contains(type);
145 const auto i = m_extensions.find(type);
147 if(i == m_extensions.end()) {
150 return i->second.get();
155 const auto type = extn->
type();
157 throw Invalid_Argument(
"cannot add the same extension twice: " + std::to_string(
static_cast<uint16_t
>(type)));
160 m_extension_codes.push_back(type);
161 m_extensions.emplace(type, std::move(extn));
179 throw TLS_Exception(TLS::Alert::DecodeError,
"Peer sent duplicated extensions");
184 const std::vector<uint8_t> extn_data = reader.
get_fixed<uint8_t>(extension_size);
185 m_raw_extension_data[
type] = extn_data;
187 this->
add(make_extension(extn_reader,
type, from, message_type));
188 extn_reader.assert_done();
194 const bool allow_unknown_extensions)
const {
197 std::vector<Extension_Code> diff;
199 found.cbegin(), found.end(), allowed_extensions.cbegin(), allowed_extensions.cend(), std::back_inserter(diff));
201 if(allow_unknown_extensions) {
204 const auto itr = std::find_if(diff.cbegin(), diff.cend(), [
this](
const auto ext_type) {
205 const auto ext = get(ext_type);
206 return ext && ext->is_implemented();
210 return itr != diff.cend();
213 return !diff.empty();
217 auto i = m_extensions.find(type);
219 if(i == m_extensions.end()) {
222 m_extensions.erase(i);
223 std::erase(m_extension_codes, type);
224 m_raw_extension_data.erase(type);
230 std::vector<uint8_t> buf(2);
233 for(
const auto extn_type : m_extension_codes) {
234 const auto& extn = m_extensions.at(extn_type);
240 const uint16_t extn_code =
static_cast<uint16_t
>(extn_type);
242 const std::vector<uint8_t> extn_val = extn->serialize(whoami);
250 buf.push_back(
get_byte<0>(
static_cast<uint16_t
>(extn_val.size())));
251 buf.push_back(
get_byte<1>(
static_cast<uint16_t
>(extn_val.size())));
258 const uint16_t extn_size =
static_cast<uint16_t
>(buf.size() - 2);
264 if(buf.size() == 2) {
265 return std::vector<uint8_t>();
272 std::set<Extension_Code> offers;
273 for(
const auto& [extn_type, extn] : m_extensions) {
277 offers.insert(extn_type);
284 const std::set<Extension_Code> in_order(order.begin(), order.end());
286 std::vector<Extension_Code> new_codes;
287 new_codes.reserve(m_extension_codes.size());
290 for(
auto code : m_extension_codes) {
291 if(!in_order.contains(code)) {
292 new_codes.push_back(code);
300 std::unordered_set<Extension_Code> already_pushed;
301 for(
auto code : order) {
302 if(m_extensions.contains(code) && already_pushed.insert(code).second) {
303 new_codes.push_back(code);
307 m_extension_codes = std::move(new_codes);
311 m_type(
type), m_value(reader.get_fixed<uint8_t>(extension_size)) {}
330 if(extension_size != 0) {
331 throw TLS_Exception(Alert::IllegalParameter,
"Server sent non-empty SNI extension");
335 if(extension_size == 0) {
336 throw TLS_Exception(Alert::IllegalParameter,
"Client sent empty SNI extension");
344 if(name_bytes + 2 != extension_size || name_bytes < 4) {
351 const uint8_t name_type = reader.
get_byte();
358 if(!m_sni_host_name.empty()) {
359 throw Decoding_Error(
"TLS ServerNameIndicator contains more than one host_name");
361 m_sni_host_name = reader.
get_string(2, 1, 65535);
370 const uint16_t unknown_name_len = reader.
get_uint16_t();
386 std::vector<uint8_t> buf;
388 const size_t name_len = m_sni_host_name.size();
395 buf.push_back(
get_byte<0>(
static_cast<uint16_t
>(name_len + 3)));
396 buf.push_back(
get_byte<1>(
static_cast<uint16_t
>(name_len + 3)));
399 buf.push_back(
get_byte<0>(
static_cast<uint16_t
>(name_len)));
400 buf.push_back(
get_byte<1>(
static_cast<uint16_t
>(name_len)));
410 if(hostname.empty()) {
419 if(hostname.find(
':') != std::string_view::npos) {
427 BOTAN_ARG_CHECK(!protocol.empty(),
"ALPN protocol name must not be empty");
429 m_protocols.emplace_back(protocol);
433 const std::vector<std::string>&
protocols) :
436 BOTAN_ARG_CHECK(!protocol.empty(),
"ALPN protocol name must not be empty");
442 uint16_t extension_size,
444 if(extension_size < 2) {
450 size_t bytes_remaining = extension_size - 2;
452 if(name_bytes != bytes_remaining) {
453 throw Decoding_Error(
"Bad encoding of ALPN extension, bad length field");
457 if(name_bytes == 0) {
458 throw Decoding_Error(
"Empty ALPN protocol_name_list not allowed");
461 while(bytes_remaining > 0) {
462 const std::string p = reader.
get_string(1, 0, 255);
464 if(bytes_remaining < p.size() + 1) {
465 throw Decoding_Error(
"Bad encoding of ALPN, length field too long");
472 bytes_remaining -= (p.size() + 1);
474 m_protocols.push_back(p);
484 "Server sent " + std::to_string(m_protocols.size()) +
" protocols in ALPN extension response");
490 return m_protocols.front();
494 std::vector<uint8_t> buf(2);
496 for(
auto&& proto : m_protocols) {
497 if(proto.length() >= 256) {
498 throw TLS_Exception(Alert::InternalError,
"ALPN name too long");
507 buf[0] =
get_byte<0>(
static_cast<uint16_t
>(buf.size() - 2));
508 buf[1] =
get_byte<1>(
static_cast<uint16_t
>(buf.size() - 2));
515 BOTAN_ARG_CHECK(!m_certificate_types.empty(),
"at least one certificate type must be supported");
525 const std::vector<Certificate_Type>& server_preference) :
533 for(
const auto server_supported_cert_type : server_preference) {
534 if(
value_exists(certificate_type_from_client.m_certificate_types, server_supported_cert_type)) {
535 m_certificate_types.push_back(server_supported_cert_type);
545 throw TLS_Exception(Alert::UnsupportedCertificate,
"Failed to agree on certificate_type");
550 if(extension_size == 0) {
551 throw Decoding_Error(
"Certificate type extension cannot be empty");
556 if(
static_cast<size_t>(extension_size) != type_bytes.size() + 1) {
557 throw Decoding_Error(
"certificate type extension had inconsistent length");
560 if(type_bytes.empty()) {
561 throw Decoding_Error(
"Certificate type extension contains no types");
564 type_bytes.begin(), type_bytes.end(), std::back_inserter(m_certificate_types), [](
const auto type_byte) {
565 return static_cast<Certificate_Type>(type_byte);
571 if(extension_size != 1) {
572 throw Decoding_Error(
"Server's certificate type extension must be of length 1");
574 const auto type_byte = reader.
get_byte();
580 std::vector<uint8_t> result;
582 std::vector<uint8_t> type_bytes;
584 m_certificate_types.begin(), m_certificate_types.end(), std::back_inserter(type_bytes), [](
const auto type) {
585 return static_cast<uint8_t>(type);
590 result.push_back(
static_cast<uint8_t
>(m_certificate_types.front()));
605 Botan::fmt(
"Selected certificate type was not offered: {}",
613 return m_certificate_types.front();
623 std::vector<Group_Params> ec;
624 for(
auto g : m_groups) {
625 if(g.is_pure_ecc_group()) {
633 std::vector<Group_Params> dh;
634 for(
auto g : m_groups) {
635 if(g.is_in_ffdhe_range()) {
643 std::vector<uint8_t> buf(2);
645 for(
auto g : m_groups) {
646 const uint16_t
id = g.wire_code();
656 buf[0] =
get_byte<0>(
static_cast<uint16_t
>(buf.size() - 2));
657 buf[1] =
get_byte<1>(
static_cast<uint16_t
>(buf.size() - 2));
665 if(len + 2 != extension_size) {
666 throw Decoding_Error(
"Inconsistent length field in supported groups list");
678 const size_t elems = len / 2;
680 std::unordered_set<uint16_t> seen;
681 for(
size_t i = 0; i != elems; ++i) {
684 if(seen.insert(group.wire_code()).second) {
685 m_groups.push_back(group);
692std::vector<uint8_t> serialize_signature_algorithms(
const std::vector<Signature_Scheme>& schemes) {
693 BOTAN_ASSERT(schemes.size() < 256,
"Too many signature schemes");
695 std::vector<uint8_t> buf;
697 const uint16_t len =
static_cast<uint16_t
>(schemes.size() * 2);
710std::vector<Signature_Scheme> parse_signature_algorithms(TLS_Data_Reader& reader, uint16_t extension_size) {
711 uint16_t len = reader.get_uint16_t();
713 if(len + 2 != extension_size || len % 2 == 1 || len == 0) {
714 throw Decoding_Error(
"Bad encoding on signature algorithms extension");
717 std::vector<Signature_Scheme> schemes;
718 schemes.reserve(len / 2);
720 schemes.emplace_back(reader.get_uint16_t());
730 return serialize_signature_algorithms(m_schemes);
734 m_schemes(parse_signature_algorithms(reader, extension_size)) {}
737 return serialize_signature_algorithms(m_schemes);
741 m_schemes(parse_signature_algorithms(reader, extension_size)) {}
749 if(extension_size < 5) {
752 const size_t max_profile_pairs = (
static_cast<size_t>(extension_size) - 3) / 2;
753 m_pp = reader.
get_range<uint16_t>(2, 1, max_profile_pairs);
754 const std::vector<uint8_t> mki = reader.
get_range<uint8_t>(1, 0, 255);
756 if(m_pp.size() * 2 + mki.size() + 3 != extension_size) {
757 throw Decoding_Error(
"Bad encoding for SRTP protection extension");
761 throw Decoding_Error(
"Unhandled non-empty MKI for SRTP protection extension");
766 std::vector<uint8_t> buf;
768 const uint16_t pp_len =
static_cast<uint16_t
>(m_pp.size() * 2);
772 for(
const uint16_t pp : m_pp) {
783 std::vector<uint8_t> buf;
787 buf.push_back(m_versions[0].major_version());
788 buf.push_back(m_versions[0].minor_version());
793 const uint8_t len =
static_cast<uint8_t
>(m_versions.size() * 2);
798 buf.push_back(version.major_version());
799 buf.push_back(version.minor_version());
814#if defined(BOTAN_HAS_TLS_13)
816 if(offer >= Protocol_Version::TLS_V13 && policy.
allow_tls13()) {
817 m_versions.push_back(Protocol_Version::TLS_V13);
822#if defined(BOTAN_HAS_TLS_12)
824 if(offer >= Protocol_Version::DTLS_V12 && policy.
allow_dtls12()) {
825 m_versions.push_back(Protocol_Version::DTLS_V12);
828 if(offer >= Protocol_Version::TLS_V12 && policy.
allow_tls12()) {
829 m_versions.push_back(Protocol_Version::TLS_V12);
840 if(extension_size != 2) {
841 throw Decoding_Error(
"Server sent invalid supported_versions extension");
851 if(extension_size != 1 + 2 *
versions.size()) {
852 throw Decoding_Error(
"Client sent invalid supported_versions extension");
858 for(
auto v : m_versions) {
867 BOTAN_ARG_CHECK(
limit >= 64,
"RFC 8449 does not allow record size limits smaller than 64 bytes");
869 "RFC 8449 does not allow record size limits larger than 2^14+1");
873 if(extension_size != 2) {
874 throw TLS_Exception(Alert::DecodeError,
"invalid record_size_limit extension");
894 "Server requested a record size limit larger than the protocol's maximum");
902 throw TLS_Exception(Alert::IllegalParameter,
"Received a record size limit smaller than 64 bytes");
907 std::vector<uint8_t> buf;
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_STATE_CHECK(expr)
#define BOTAN_ARG_CHECK(expr, msg)
#define BOTAN_ASSERT(expr, assertion_made)
Application_Layer_Protocol_Notification(std::string_view protocol)
const std::vector< std::string > & protocols() const
std::string single_protocol() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Certificate_Type selected_certificate_type() const
Certificate_Type_Base(std::vector< Certificate_Type > supported_cert_types)
void validate_selection(const Certificate_Type_Base &from_server) const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Extension_Code type() const override
Client_Certificate_Type(const Client_Certificate_Type &cct, const Policy &policy)
Certificate_Type_Base(std::vector< Certificate_Type > supported_cert_types)
virtual Extension_Code type() const =0
std::vector< uint8_t > serialize(Connection_Side whoami) const
void reorder(const std::vector< Extension_Code > &order)
void deserialize(TLS_Data_Reader &reader, Connection_Side from, Handshake_Type message_type)
bool remove_extension(Extension_Code type)
std::set< Extension_Code > extension_types() const
void add(std::unique_ptr< Extension > extn)
bool contains_other_than(const std::set< Extension_Code > &allowed_extensions, bool allow_unknown_extensions=false) const
virtual bool allow_tls12() const
virtual bool allow_tls13() const
virtual bool allow_dtls12() const
bool is_datagram_protocol() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Record_Size_Limit(uint16_t limit)
SRTP_Protection_Profiles(const std::vector< uint16_t > &pp)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Server_Certificate_Type(const Server_Certificate_Type &sct, const Policy &policy)
Certificate_Type_Base(std::vector< Certificate_Type > supported_cert_types)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
static bool hostname_acceptable_for_sni(std::string_view hostname)
Server_Name_Indicator(std::string_view host_name)
Signature_Algorithms_Cert(std::vector< Signature_Scheme > schemes)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Signature_Algorithms(std::vector< Signature_Scheme > schemes)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Supported_Groups(const std::vector< Group_Params > &groups)
std::vector< Group_Params > ec_groups() const
const std::vector< Group_Params > & groups() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::vector< Group_Params > dh_groups() const
bool supports(Protocol_Version version) const
Supported_Versions(Protocol_Version version, const Policy &policy)
const std::vector< Protocol_Version > & versions() const
std::vector< uint8_t > serialize(Connection_Side whoami) const override
std::string get_string(size_t len_bytes, size_t min_bytes, size_t max_bytes)
bool has_remaining() const
void discard_next(size_t bytes)
std::vector< T > get_range(size_t len_bytes, size_t min_elems, size_t max_elems)
size_t remaining_bytes() const
std::vector< uint8_t > get_tls_length_value(size_t len_bytes)
std::vector< T > get_fixed(size_t size)
std::vector< uint8_t > serialize(Connection_Side whoami) const override
Unknown_Extension(Extension_Code type, TLS_Data_Reader &reader, uint16_t extension_size)
Extension_Code type() const override
std::string certificate_type_to_string(Certificate_Type type)
void append_tls_length_value(std::vector< uint8_t, Alloc > &buf, const T *vals, size_t vals_size, size_t tag_size)
@ CertSignatureAlgorithms
@ ApplicationLayerProtocolNegotiation
@ CertificateStatusRequest
constexpr uint8_t get_byte(T input)
std::span< const uint8_t > as_span_of_bytes(const char *s, size_t len)
std::string fmt(std::string_view format, const T &... args)
bool value_exists(const std::vector< T > &vec, const V &val)
std::optional< uint32_t > string_to_ipv4(std::string_view str)