10#include <botan/tls_extensions_13.h>
12#include <botan/credentials_manager.h>
13#include <botan/hash.h>
14#include <botan/tls_callbacks.h>
15#include <botan/tls_exceptn.h>
16#include <botan/tls_psk_identity_13.h>
17#include <botan/tls_session.h>
18#include <botan/tls_session_manager.h>
19#include <botan/internal/ct_utils.h>
20#include <botan/internal/stl_util.h>
21#include <botan/internal/tls_cipher_state.h>
22#include <botan/internal/tls_reader.h>
30decltype(
auto) calculate_age(std::chrono::system_clock::time_point then, std::chrono::system_clock::time_point now) {
35 return std::chrono::duration_cast<std::chrono::milliseconds>(now - then);
// Builds a resumption PSK from a stored session and its handle.
// - identity:   the handle's opaque bytes, paired with the ticket age computed by
//               calculate_age() from the session's start_time() to `timestamp`
//               (a millisecond duration) and the session's session_age_add()
//               value (presumably used by PskIdentity to obfuscate the age —
//               TODO confirm).
// - PRF/secret: the stored session's ciphersuite PRF and its extracted master secret.
// - type:       always Cipher_State::PSK_Type::Resumption.
// Note: extract_master_secret() is called on the session object held by
// `session_to_resume`, which is taken by non-const reference.
40 Client_PSK(Session_with_Handle& session_to_resume, std::chrono::system_clock::time_point timestamp) :
41 Client_PSK(PskIdentity(session_to_resume.handle.opaque_handle(),
42 calculate_age(session_to_resume.session.start_time(), timestamp),
43 session_to_resume.session.session_age_add()),
44 session_to_resume.session.ciphersuite().prf_algo(),
45 session_to_resume.session.extract_master_secret(),
46 Cipher_State::PSK_Type::Resumption) {}
49 explicit Client_PSK(ExternalPSK&& psk) :
52 psk.extract_master_secret(),
53 psk.is_imported() ? Cipher_State::PSK_Type::Imported : Cipher_State::PSK_Type::External) {}
55 Client_PSK(PskIdentity
id, std::vector<uint8_t> bndr) :
56 m_identity(std::move(id)), m_binder(std::move(bndr)), m_is_resumption(false) {}
58 Client_PSK(PskIdentity
id,
59 std::string_view prf_algo,
62 m_identity(std::move(id)),
77 m_binder(HashFunction::create_or_throw(prf_algo)->output_length()),
78 m_is_resumption(psk_type == Cipher_State::PSK_Type::Resumption),
80 Cipher_State::init_with_psk(
Connection_Side::
Client, psk_type, std::move(master_secret), prf_algo)) {}
82 const PskIdentity& identity()
const {
return m_identity; }
84 const std::vector<uint8_t>& binder()
const {
return m_binder; }
86 bool is_resumption()
const {
return m_is_resumption; }
88 void set_binder(std::vector<uint8_t> binder) { m_binder = std::move(binder); }
90 const Cipher_State& cipher_state()
const {
92 return *m_cipher_state;
95 std::unique_ptr<Cipher_State> take_cipher_state() {
return std::exchange(m_cipher_state,
nullptr); }
// The identity of this PSK; exposed read-only via identity().
98 PskIdentity m_identity;
// Binder bytes. When built from a secret this is sized to the PRF hash's output
// length (value-initialized, i.e. zero bytes) and later overwritten via set_binder().
99 std::vector<uint8_t> m_binder;
// True only when constructed with Cipher_State::PSK_Type::Resumption; false otherwise.
100 bool m_is_resumption;
// Cipher state initialized from the PSK's master secret. Null when constructed
// from an identity/binder pair, or after take_cipher_state() has been called.
104 std::unique_ptr<Cipher_State> m_cipher_state;
109 explicit Server_PSK(uint16_t
id) : m_selected_identity(id), m_session_to_resume_or_psk(std::monostate()) {}
111 Server_PSK(uint16_t
id, Session session) :
112 m_selected_identity(id), m_session_to_resume_or_psk(std::move(session)) {}
114 Server_PSK(uint16_t
id, ExternalPSK psk) : m_selected_identity(id), m_session_to_resume_or_psk(std::move(psk)) {}
116 uint16_t selected_identity()
const {
return m_selected_identity; }
118 std::variant<std::monostate, Session, ExternalPSK> take_session_to_resume_or_psk() {
119 BOTAN_STATE_CHECK(!std::holds_alternative<std::monostate>(m_session_to_resume_or_psk));
120 return std::exchange(m_session_to_resume_or_psk, std::monostate());
// Index into the client's list of offered PSKs (bounds-checked by the client
// before use in take_selected_psk_info()).
124 uint16_t m_selected_identity;
// The session to resume or the external PSK chosen. Holds std::monostate when
// only the index is known (e.g. after decoding the server's 2-byte extension)
// or after take_session_to_resume_or_psk() has moved the value out.
127 std::variant<std::monostate, Session, ExternalPSK> m_session_to_resume_or_psk;
132class PSK::PSK_Internal {
134 explicit PSK_Internal(Server_PSK srv_psk) : psk(std::move(srv_psk)) {}
136 explicit PSK_Internal(std::vector<Client_PSK> clt_psks) : psk(std::move(clt_psks)) {}
139 std::variant<std::vector<Client_PSK>, Server_PSK> psk;
144 if(extension_size != 2) {
145 throw TLS_Exception(Alert::DecodeError,
"Server provided a malformed PSK extension");
149 m_impl = std::make_unique<PSK_Internal>(Server_PSK(selected_id));
152 const auto identities_offset = reader.
read_so_far();
154 std::vector<PskIdentity> psk_identities;
168 const auto obfuscated_ticket_age = reader.
get_uint32_t();
169 psk_identities.emplace_back(std::move(identity), obfuscated_ticket_age);
172 if(psk_identities.empty()) {
176 if(reader.
read_so_far() - identities_offset != identities_length) {
177 throw TLS_Exception(Alert::DecodeError,
"Inconsistent PSK identity list");
183 if(binders_length == 0) {
184 throw TLS_Exception(Alert::DecodeError,
"Empty PSK binders list");
187 std::vector<Client_PSK> psks;
188 for(
auto& psk_identity : psk_identities) {
190 throw TLS_Exception(Alert::IllegalParameter,
"Not enough PSK binders");
201 if(reader.
read_so_far() - binders_offset != binders_length) {
202 throw TLS_Exception(Alert::IllegalParameter,
"Too many PSK binders");
205 m_impl = std::make_unique<PSK_Internal>(std::move(psks));
207 throw TLS_Exception(Alert::DecodeError,
"Found a PSK extension in an unexpected handshake message");
211PSK::PSK(std::optional<Session_with_Handle>& session_to_resume, std::vector<ExternalPSK> psks,
Callbacks& callbacks) {
212 std::vector<Client_PSK> cpsk;
214 if(session_to_resume.has_value()) {
218 for(
auto&& psk : psks) {
219 cpsk.emplace_back(std::move(psk));
222 m_impl = std::make_unique<PSK_Internal>(std::move(cpsk));
226 m_impl(std::make_unique<PSK_Internal>(Server_PSK(psk_index, std::move(session_to_resume)))) {}
228PSK::PSK(ExternalPSK psk, uint16_t psk_index) :
229 m_impl(std::make_unique<PSK_Internal>(Server_PSK(psk_index, std::move(psk)))) {}
234 if(std::holds_alternative<Server_PSK>(m_impl->psk)) {
239 return std::get<std::vector<Client_PSK>>(m_impl->psk).
empty();
247 const auto id = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
248 auto& ids = std::get<std::vector<Client_PSK>>(m_impl->psk);
255 if(
id >= ids.size()) {
256 throw TLS_Exception(Alert::IllegalParameter,
"PSK identity selected by server is out of bounds");
259 auto& selected_psk = ids.at(
id);
260 auto cipher_state = selected_psk.take_cipher_state();
264 auto psk_id = [&]() -> std::optional<std::string> {
265 if(selected_psk.is_resumption()) {
268 return selected_psk.identity().identity_as_string();
279 if(!cipher_state->is_compatible_with(cipher)) {
280 throw TLS_Exception(Alert::IllegalParameter,
"PSK and ciphersuite selected by server are not compatible");
283 return {std::move(psk_id), std::move(cipher_state)};
294 auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
299 std::vector<PskIdentity> psk_identities;
301 psks.begin(), psks.end(), std::back_inserter(psk_identities), [&](
const auto& psk) { return psk.identity(); });
302 if(
auto selected_session =
304 auto& [session, psk_index] = selected_session.value();
310 if(session.server_info().hostname() == host) {
315 if(session.ciphersuite().prf_algo() != cipher.
prf_algo()) {
317 "Application chose a ticket that is not compatible with the negotiated ciphersuite");
320 return std::unique_ptr<PSK>(
new PSK(std::move(session), psk_index));
327 std::vector<std::string> psk_ids;
328 std::transform(psks.begin(), psks.end(), std::back_inserter(psk_ids), [&](
const auto& psk) {
329 return psk.identity().identity_as_string();
331 if(
auto selected_psk =
333 auto& psk = selected_psk.value();
339 if(psk.prf_algo() != cipher.
prf_algo()) {
341 "Application chose a PSK that is not compatible with the negotiated ciphersuite");
344 const auto selected_itr =
345 std::find_if(psk_identities.begin(), psk_identities.end(), [&](
const auto& offered_psk) {
346 return offered_psk.identity_as_string() == psk.identity();
348 if(selected_itr == psk_identities.end()) {
350 "Application provided a PSK with an identity that was not offered by the client");
364 return std::unique_ptr<PSK>(
365 new PSK(std::move(psk),
static_cast<uint16_t
>(std::distance(psk_identities.begin(), selected_itr))));
373 auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
375 const auto r = std::remove_if(psks.begin(), psks.end(), [&](
const auto& psk) {
376 const auto& cipher_state = psk.cipher_state();
377 return !cipher_state.is_compatible_with(cipher);
379 psks.erase(r, psks.end());
386 [](
auto out) -> std::variant<Session, ExternalPSK> {
387 if constexpr(std::is_same_v<
decltype(out), std::monostate>) {
393 std::get<Server_PSK>(m_impl->psk).take_session_to_resume_or_psk());
397 std::vector<uint8_t> result;
400 [&](
const Server_PSK& psk) {
403 const uint16_t
id = psk.selected_identity();
407 [&](
const std::vector<Client_PSK>& psks) {
410 std::vector<uint8_t> identities;
411 std::vector<uint8_t> binders;
412 for(
const auto& psk : psks) {
413 const auto& psk_identity = psk.identity();
416 const uint32_t obfuscated_ticket_age = psk_identity.obfuscated_age();
417 identities.push_back(
get_byte<0>(obfuscated_ticket_age));
418 identities.push_back(
get_byte<1>(obfuscated_ticket_age));
419 identities.push_back(
get_byte<2>(obfuscated_ticket_age));
420 identities.push_back(
get_byte<3>(obfuscated_ticket_age));
437 for(
auto& psk : std::get<std::vector<Client_PSK>>(m_impl->psk)) {
438 auto tth = truncated_transcript_hash.
clone();
439 const auto& cipher_state = psk.cipher_state();
441 psk.set_binder(cipher_state.psk_binder_mac(tth.truncated()));
449 const uint16_t index = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
450 const auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
453 const auto& expected_binder = psks[index].binder();
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_STATE_CHECK(expr)
#define BOTAN_ASSERT_NONNULL(ptr)
virtual std::optional< TLS::ExternalPSK > choose_preshared_key(std::string_view host, TLS::Connection_Side whoami, const std::vector< std::string > &identities, const std::optional< std::string > &prf=std::nullopt)
virtual std::chrono::system_clock::time_point tls_current_timestamp()
std::string prf_algo() const
std::pair< std::optional< std::string >, std::unique_ptr< Cipher_State > > take_selected_psk_info(const PSK &server_psk, const Ciphersuite &cipher)
std::variant< Session, ExternalPSK > take_session_to_resume_or_psk()
void filter(const Ciphersuite &cipher)
std::vector< uint8_t > serialize(Connection_Side side) const override
bool validate_binder(const PSK &server_psk, const std::vector< uint8_t > &binder) const
void calculate_binders(const Transcript_Hash_State &truncated_transcript_hash)
std::unique_ptr< PSK > select_offered_psk(std::string_view host, const Ciphersuite &cipher, Session_Manager &session_mgr, Credentials_Manager &credentials_mgr, Callbacks &callbacks, const Policy &policy)
bool empty() const override
PSK(TLS_Data_Reader &reader, uint16_t extension_size, Handshake_Type message_type)
virtual std::optional< std::pair< Session, uint16_t > > choose_from_offered_tickets(const std::vector< PskIdentity > &tickets, std::string_view hash_function, Callbacks &callbacks, const Policy &policy)
bool has_remaining() const
size_t read_so_far() const
std::vector< uint8_t > get_tls_length_value(size_t len_bytes)
Transcript_Hash_State clone() const
void set_algorithm(std::string_view algo_spec)
constexpr CT::Mask< T > is_equal(const T x[], const T y[], size_t len)
Strong< std::string, struct PresharedKeyID_ > PresharedKeyID
holds a PSK identity as used in TLS 1.3
void append_tls_length_value(std::vector< uint8_t, Alloc > &buf, const T *vals, size_t vals_size, size_t tag_size)
constexpr uint8_t get_byte(T input)
std::vector< T, secure_allocator< T > > secure_vector