10#include <botan/tls_extensions.h>
12#include <botan/credentials_manager.h>
13#include <botan/tls_callbacks.h>
14#include <botan/tls_exceptn.h>
15#include <botan/tls_session.h>
16#include <botan/tls_session_manager.h>
17#include <botan/internal/stl_util.h>
18#include <botan/internal/tls_cipher_state.h>
19#include <botan/internal/tls_reader.h>
24#if defined(BOTAN_HAS_TLS_13)
// Returns the time elapsed from `then` to `now`, truncated to whole
// milliseconds by std::chrono::duration_cast (negative if `now` precedes
// `then`). The decltype(auto) return type deduces std::chrono::milliseconds.
30decltype(
auto) calculate_age(std::chrono::system_clock::time_point then, std::chrono::system_clock::time_point now) {
35 return std::chrono::duration_cast<std::chrono::milliseconds>(now - then);
// Client-side resumption PSK built from a stored session ticket.
// - identity: the session's opaque handle, the ticket age (from the session's
//   start time to `timestamp`, see calculate_age) and the session's age_add value
// - PRF: the PRF algorithm of the session's ciphersuite
// - key material: the session's master secret, tagged as PSK_Type::Resumption
// NOTE(review): `session_to_resume` is taken by non-const reference, presumably
// because extract_master_secret() removes the secret from the session - confirm.
40 Client_PSK(Session_with_Handle& session_to_resume, std::chrono::system_clock::time_point timestamp) :
41 Client_PSK(PskIdentity(session_to_resume.handle.opaque_handle(),
42 calculate_age(session_to_resume.session.start_time(), timestamp),
43 session_to_resume.session.session_age_add()),
44 session_to_resume.session.ciphersuite().prf_algo(),
45 session_to_resume.session.extract_master_secret(),
46 Cipher_State::PSK_Type::Resumption) {}
// Client-side PSK built from an externally provisioned key. The key material
// comes from psk.extract_master_secret() and is tagged PSK_Type::External.
48 Client_PSK(ExternalPSK&& psk) :
51 psk.extract_master_secret(),
52 Cipher_State::PSK_Type::External) {}
// PSK entry decoded from the wire: an identity plus the binder that was received.
// It is never marked as a resumption PSK, and this constructor does not set
// m_cipher_state, so cipher_state() must not be used on such entries.
54 Client_PSK(PskIdentity
id, std::vector<uint8_t> bndr) :
55 m_identity(std::move(id)), m_binder(std::move(bndr)), m_is_resumption(false) {}
// Client-side PSK with a locally derived cipher state.
// - m_binder starts as a zero-filled placeholder sized to the output length of
//   `prf_algo`; the real binder is stored later through set_binder().
// - m_is_resumption records whether `psk_type` is PSK_Type::Resumption.
// - m_cipher_state is initialized from `master_secret` for the client side of
//   the connection, using `psk_type` and `prf_algo`.
57 Client_PSK(PskIdentity
id,
58 std::string_view prf_algo,
59 secure_vector<uint8_t>&& master_secret,
61 m_identity(std::move(id)),
76 m_binder(HashFunction::create_or_throw(prf_algo)->output_length()),
77 m_is_resumption(psk_type == Cipher_State::PSK_Type::Resumption),
79 Cipher_State::init_with_psk(
Connection_Side::
Client, psk_type, std::move(master_secret), prf_algo)) {}
// The PSK identity of this entry (offered locally or decoded from the wire).
81 const PskIdentity& identity()
const {
return m_identity; }
// Current binder bytes; for locally created PSKs this is a placeholder until
// set_binder() stores the computed binder MAC.
83 const std::vector<uint8_t>& binder()
const {
return m_binder; }
// True iff this entry was created as a session-resumption PSK.
85 bool is_resumption()
const {
return m_is_resumption; }
// Replaces the stored binder (used once the binder MAC has been computed).
87 void set_binder(std::vector<uint8_t> binder) { m_binder = std::move(binder); }
// Read-only access to this PSK's cipher state. Dereferences m_cipher_state, so
// callers must ensure it is still present, i.e. it was created by the
// constructor and not yet moved out via take_cipher_state().
89 const Cipher_State& cipher_state()
const {
91 return *m_cipher_state;
// Moves the cipher state out to the caller and leaves this entry holding nullptr.
94 std::unique_ptr<Cipher_State> take_cipher_state() {
return std::exchange(m_cipher_state,
nullptr); }
// Identity of this PSK (offered or received).
97 PskIdentity m_identity;
// Binder bytes associated with the identity.
98 std::vector<uint8_t> m_binder;
// Cipher state derived from the PSK; null if never created or already taken.
103 std::unique_ptr<Cipher_State> m_cipher_state;
// Only the selected identity index is known (as when decoded from the peer's
// extension); no session or external PSK is attached.
108 Server_PSK(uint16_t
id) : m_selected_identity(id), m_session_to_resume_or_psk(std::monostate()) {}
// Selection of a resumption session matched at client identity index `id`.
110 Server_PSK(uint16_t
id, Session session) :
111 m_selected_identity(id), m_session_to_resume_or_psk(std::move(session)) {}
// Selection of an external PSK matched at client identity index `id`.
113 Server_PSK(uint16_t
id, ExternalPSK psk) : m_selected_identity(id), m_session_to_resume_or_psk(std::move(psk)) {}
// Index into the client's offered identity list that was selected.
115 uint16_t selected_identity()
const {
return m_selected_identity; }
// Hands out the attached session or external PSK at most once: the state check
// rejects the empty (monostate) case, and the member is reset to monostate
// afterwards, so a second call fails that check.
117 std::variant<std::monostate, Session, ExternalPSK> take_session_to_resume_or_psk() {
118 BOTAN_STATE_CHECK(!std::holds_alternative<std::monostate>(m_session_to_resume_or_psk));
119 return std::exchange(m_session_to_resume_or_psk, std::monostate());
// Selected index into the client's identity list.
123 uint16_t m_selected_identity;
// Chosen session, chosen external PSK, or monostate when only the index is known.
126 std::variant<std::monostate, Session, ExternalPSK> m_session_to_resume_or_psk;
// Storage behind PSK: either the list of client-side PSKs (offered locally or
// decoded from the wire) or the server's single selection.
131class PSK::PSK_Internal {
133 PSK_Internal(Server_PSK srv_psk) : psk(std::move(srv_psk)) {}
135 PSK_Internal(std::vector<Client_PSK> clt_psks) : psk(std::move(clt_psks)) {}
138 std::variant<std::vector<Client_PSK>, Server_PSK> psk;
// Server's response form: the extension body must be exactly the 2-byte
// selected identity index; anything else is a decode error.
143 if(extension_size != 2) {
144 throw TLS_Exception(Alert::DecodeError,
"Server provided a malformed PSK extension");
148 m_impl = std::make_unique<PSK_Internal>(Server_PSK(selected_id));
// Client's offer form: parse the identity list (each identity followed by a
// 32-bit obfuscated ticket age), remembering the start offset so the consumed
// size can be checked against the declared list length.
151 const auto identities_offset = reader.
read_so_far();
153 std::vector<PskIdentity> psk_identities;
156 const auto obfuscated_ticket_age = reader.
get_uint32_t();
157 psk_identities.emplace_back(std::move(identity), obfuscated_ticket_age);
// Guard against an empty identity list.
160 if(psk_identities.empty()) {
// The bytes consumed must match the declared identity list length.
164 if(reader.
read_so_far() - identities_offset != identities_length) {
165 throw TLS_Exception(Alert::DecodeError,
"Inconsistent PSK identity list");
// Binder list: must be non-empty, provide one binder per identity, and be
// consumed exactly.
171 if(binders_length == 0) {
172 throw TLS_Exception(Alert::DecodeError,
"Empty PSK binders list");
175 std::vector<Client_PSK> psks;
176 for(
auto& psk_identity : psk_identities) {
// Fewer binders than identities.
178 throw TLS_Exception(Alert::IllegalParameter,
"Not enough PSK binders");
// Leftover bytes after every identity has been paired with a binder.
184 if(reader.
read_so_far() - binders_offset != binders_length) {
185 throw TLS_Exception(Alert::IllegalParameter,
"Too many PSK binders");
188 m_impl = std::make_unique<PSK_Internal>(std::move(psks));
// Any other handshake message carrying this extension is rejected.
190 throw TLS_Exception(Alert::DecodeError,
"Found a PSK extension in an unexpected handshake message");
// Client side: builds the list of PSKs to offer. A resumption session, if
// present, is handled first (presumably as a resumption PSK whose age is taken
// from `callbacks`' current timestamp - confirm); then every external PSK is
// moved into the list.
194PSK::PSK(std::optional<Session_with_Handle>& session_to_resume, std::vector<ExternalPSK> psks,
Callbacks& callbacks) {
195 std::vector<Client_PSK> cpsk;
197 if(session_to_resume.has_value()) {
201 for(
auto&& psk : psks) {
202 cpsk.emplace_back(std::move(psk));
205 m_impl = std::make_unique<PSK_Internal>(std::move(cpsk));
// Server side: records the chosen resumption session and the index of the
// client identity it matched.
209 m_impl(std::make_unique<PSK_Internal>(Server_PSK(psk_index, std::move(session_to_resume)))) {}
// Server side: records the chosen external PSK and the index of the client
// identity it matched.
211PSK::PSK(ExternalPSK psk,
const uint16_t psk_index) :
212 m_impl(std::make_unique<PSK_Internal>(Server_PSK(psk_index, std::move(psk)))) {}
217 if(std::holds_alternative<Server_PSK>(m_impl->psk)) {
// Client side: empty when no PSK candidates remain (filter() may remove all).
222 return std::get<std::vector<Client_PSK>>(m_impl->psk).
empty();
// Client side: look up, in our offered list, the PSK the server selected by index.
230 const auto id = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
231 auto& ids = std::get<std::vector<Client_PSK>>(m_impl->psk);
// The index comes from the peer; reject it if it is outside our offered list.
238 if(
id >= ids.size()) {
239 throw TLS_Exception(Alert::IllegalParameter,
"PSK identity selected by server is out of bounds");
242 auto& selected_psk = ids.at(
id);
// Ownership of the cipher state moves to the caller; the entry keeps nullptr.
243 auto cipher_state = selected_psk.take_cipher_state();
// Identity reported to the caller: external PSKs yield their identity string;
// the resumption branch presumably yields std::nullopt - confirm.
247 auto psk_id = [&]() -> std::optional<std::string> {
248 if(selected_psk.is_resumption()) {
251 return selected_psk.identity().identity_as_string();
// The PSK's cipher state must be usable with the ciphersuite the server chose.
262 if(!cipher_state->is_compatible_with(cipher)) {
263 throw TLS_Exception(Alert::IllegalParameter,
"PSK and ciphersuite selected by server are not compatible");
266 return {std::move(psk_id), std::move(cipher_state)};
// Server side: choose one of the client's offered PSKs. A resumption ticket is
// tried first; otherwise an external PSK is looked up.
277 auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
// Copy the offered identities, presumably so the session manager can pick a ticket.
282 std::vector<PskIdentity> psk_identities;
284 psks.begin(), psks.end(), std::back_inserter(psk_identities), [&](
const auto& psk) { return psk.identity(); });
// Ticket path: a chosen session must use the negotiated ciphersuite's PRF.
285 if(
auto selected_session =
287 auto& [session, psk_index] = selected_session.value();
293 if(session.ciphersuite().prf_algo() != cipher.
prf_algo()) {
295 "Application chose a ticket that is not compatible with the negotiated ciphersuite");
// Raw `new`, presumably because this PSK constructor is not reachable from
// std::make_unique - confirm before changing.
298 return std::unique_ptr<PSK>(
new PSK(std::move(session), psk_index));
// External-PSK path: the offered identities are presented as plain strings.
304 std::vector<std::string> psk_ids;
305 std::transform(psks.begin(), psks.end(), std::back_inserter(psk_ids), [&](
const auto& psk) {
306 return psk.identity().identity_as_string();
308 if(
auto selected_psk =
310 auto& psk = selected_psk.value();
// The chosen PSK must use the negotiated ciphersuite's PRF.
316 if(psk.prf_algo() != cipher.
prf_algo()) {
318 "Application chose a PSK that is not compatible with the negotiated ciphersuite");
// Map the chosen PSK back to its position in the client's list; it must be an
// identity the client actually offered.
321 const auto selected_itr =
322 std::find_if(psk_identities.begin(), psk_identities.end(), [&](
const auto& offered_psk) {
323 return offered_psk.identity_as_string() == psk.identity();
325 if(selected_itr == psk_identities.end()) {
327 "Application provided a PSK with an identity that was not offered by the client");
// NOTE(review): the uint16_t cast assumes fewer than 65536 offered identities,
// which the 16-bit length prefix of the identity list should guarantee - confirm.
341 return std::unique_ptr<PSK>(
342 new PSK(std::move(psk),
static_cast<uint16_t
>(std::distance(psk_identities.begin(), selected_itr))));
// Client side: drop every offered PSK whose cipher state is not compatible with
// `cipher` (erase-remove idiom).
350 auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
352 const auto r = std::remove_if(psks.begin(), psks.end(), [&](
const auto& psk) {
353 const auto& cipher_state = psk.cipher_state();
354 return !cipher_state.is_compatible_with(cipher);
356 psks.erase(r, psks.end());
// Server side: narrows the taken variant to std::variant<Session, ExternalPSK>.
// Server_PSK::take_session_to_resume_or_psk() already rejects monostate; the
// `if constexpr` branch is presumably there so the visitor is well-formed for
// every alternative.
363 [](
auto out) -> std::variant<Session, ExternalPSK> {
364 if constexpr(std::is_same_v<
decltype(out), std::monostate>) {
370 std::get<Server_PSK>(m_impl->psk).take_session_to_resume_or_psk());
374 std::vector<uint8_t> result;
// Server form: the body is just the selected identity index as two bytes
// (presumably most-significant byte first - confirm get_byte<0> is the MSB).
377 [&](
const Server_PSK& psk) {
380 const uint16_t
id = psk.selected_identity();
381 result.push_back(get_byte<0>(
id));
382 result.push_back(get_byte<1>(
id));
// Client form: identities (each followed by its 4-byte obfuscated ticket age)
// and binders are accumulated into separate buffers, presumably so each list
// can be length-prefixed afterwards.
384 [&](
const std::vector<Client_PSK>& psks) {
387 std::vector<uint8_t> identities;
388 std::vector<uint8_t> binders;
389 for(
const auto& psk : psks) {
390 const auto& psk_identity = psk.identity();
393 const uint32_t obfuscated_ticket_age = psk_identity.obfuscated_age();
394 identities.push_back(get_byte<0>(obfuscated_ticket_age));
395 identities.push_back(get_byte<1>(obfuscated_ticket_age));
396 identities.push_back(get_byte<2>(obfuscated_ticket_age));
397 identities.push_back(get_byte<3>(obfuscated_ticket_age));
// Client side: compute each offered PSK's binder MAC over the truncated
// transcript and store it. The transcript hash is cloned per PSK so every
// binder starts from the same state (presumably because each clone may be
// switched to that PSK's own hash algorithm - confirm).
414 for(
auto& psk : std::get<std::vector<Client_PSK>>(m_impl->psk)) {
415 auto tth = truncated_transcript_hash.
clone();
416 const auto& cipher_state = psk.cipher_state();
418 psk.set_binder(cipher_state.psk_binder_mac(tth.truncated()));
// Compares `binder` with the binder stored for the identity selected in
// `server_psk`.
426 const uint16_t index = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
427 const auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
// NOTE(review): confirm `index < psks.size()` is enforced before the unchecked
// operator[] below.
// NOTE(review): std::vector::operator== may stop at the first differing byte,
// so this MAC comparison is not constant-time. Consider a constant-time
// comparison (e.g. Botan's constant_time_compare) - confirm the available helper.
430 return psks[index].binder() == binder;
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_STATE_CHECK(expr)
#define BOTAN_ASSERT_NONNULL(ptr)
virtual std::optional< TLS::ExternalPSK > choose_preshared_key(std::string_view host, TLS::Connection_Side whoami, const std::vector< std::string > &identities, const std::optional< std::string > &prf=std::nullopt)
virtual std::chrono::system_clock::time_point tls_current_timestamp()
std::string prf_algo() const
std::pair< std::optional< std::string >, std::unique_ptr< Cipher_State > > take_selected_psk_info(const PSK &server_psk, const Ciphersuite &cipher)
std::variant< Session, ExternalPSK > take_session_to_resume_or_psk()
void filter(const Ciphersuite &cipher)
std::vector< uint8_t > serialize(Connection_Side side) const override
bool validate_binder(const PSK &server_psk, const std::vector< uint8_t > &binder) const
void calculate_binders(const Transcript_Hash_State &truncated_transcript_hash)
std::unique_ptr< PSK > select_offered_psk(std::string_view host, const Ciphersuite &cipher, Session_Manager &session_mgr, Credentials_Manager &credentials_mgr, Callbacks &callbacks, const Policy &policy)
bool empty() const override
PSK(TLS_Data_Reader &reader, uint16_t extension_size, Handshake_Type message_type)
virtual std::optional< std::pair< Session, uint16_t > > choose_from_offered_tickets(const std::vector< PskIdentity > &tickets, std::string_view hash_function, Callbacks &callbacks, const Policy &policy)
bool has_remaining() const
size_t read_so_far() const
std::vector< uint8_t > get_tls_length_value(size_t len_bytes)
Transcript_Hash_State clone() const
void set_algorithm(std::string_view algo_spec)
Strong< std::string, struct PresharedKeyID_ > PresharedKeyID
holds a PSK identity as used in TLS 1.3
void append_tls_length_value(std::vector< uint8_t, Alloc > &buf, const T *vals, size_t vals_size, size_t tag_size)