10#include <botan/tls_extensions.h>
12#include <botan/credentials_manager.h>
13#include <botan/tls_callbacks.h>
14#include <botan/tls_exceptn.h>
15#include <botan/tls_session.h>
16#include <botan/tls_session_manager.h>
17#include <botan/internal/stl_util.h>
18#include <botan/internal/tls_cipher_state.h>
19#include <botan/internal/tls_reader.h>
24#if defined(BOTAN_HAS_TLS_13)
30decltype(
auto) calculate_age(std::chrono::system_clock::time_point then, std::chrono::system_clock::time_point now) {
35 return std::chrono::duration_cast<std::chrono::milliseconds>(now - then);
40 Client_PSK(Session_with_Handle& session_to_resume, std::chrono::system_clock::time_point timestamp) :
41 Client_PSK(PskIdentity(session_to_resume.handle.opaque_handle(),
42 calculate_age(session_to_resume.session.start_time(), timestamp),
43 session_to_resume.session.session_age_add()),
44 session_to_resume.session.ciphersuite().prf_algo(),
45 session_to_resume.session.extract_master_secret(),
46 Cipher_State::PSK_Type::Resumption) {}
49 Client_PSK(ExternalPSK&& psk) :
52 psk.extract_master_secret(),
53 Cipher_State::PSK_Type::External) {}
55 Client_PSK(PskIdentity
id, std::vector<uint8_t> bndr) :
56 m_identity(std::move(id)), m_binder(std::move(bndr)), m_is_resumption(false) {}
58 Client_PSK(PskIdentity
id,
59 std::string_view prf_algo,
62 m_identity(std::move(id)),
77 m_binder(HashFunction::create_or_throw(prf_algo)->output_length()),
78 m_is_resumption(psk_type == Cipher_State::PSK_Type::Resumption),
80 Cipher_State::init_with_psk(
Connection_Side::
Client, psk_type, std::move(master_secret), prf_algo)) {}
82 const PskIdentity& identity()
const {
return m_identity; }
84 const std::vector<uint8_t>& binder()
const {
return m_binder; }
86 bool is_resumption()
const {
return m_is_resumption; }
88 void set_binder(std::vector<uint8_t> binder) { m_binder = std::move(binder); }
90 const Cipher_State& cipher_state()
const {
92 return *m_cipher_state;
95 std::unique_ptr<Cipher_State> take_cipher_state() {
return std::exchange(m_cipher_state,
nullptr); }
98 PskIdentity m_identity;
99 std::vector<uint8_t> m_binder;
100 bool m_is_resumption;
104 std::unique_ptr<Cipher_State> m_cipher_state;
109 Server_PSK(uint16_t
id) : m_selected_identity(id), m_session_to_resume_or_psk(std::monostate()) {}
111 Server_PSK(uint16_t
id, Session session) :
112 m_selected_identity(id), m_session_to_resume_or_psk(std::move(session)) {}
114 Server_PSK(uint16_t
id, ExternalPSK psk) : m_selected_identity(id), m_session_to_resume_or_psk(std::move(psk)) {}
116 uint16_t selected_identity()
const {
return m_selected_identity; }
118 std::variant<std::monostate, Session, ExternalPSK> take_session_to_resume_or_psk() {
119 BOTAN_STATE_CHECK(!std::holds_alternative<std::monostate>(m_session_to_resume_or_psk));
120 return std::exchange(m_session_to_resume_or_psk, std::monostate());
124 uint16_t m_selected_identity;
127 std::variant<std::monostate, Session, ExternalPSK> m_session_to_resume_or_psk;
132class PSK::PSK_Internal {
134 PSK_Internal(Server_PSK srv_psk) : psk(std::move(srv_psk)) {}
136 PSK_Internal(std::vector<Client_PSK> clt_psks) : psk(std::move(clt_psks)) {}
139 std::variant<std::vector<Client_PSK>, Server_PSK> psk;
144 if(extension_size != 2) {
145 throw TLS_Exception(Alert::DecodeError,
"Server provided a malformed PSK extension");
149 m_impl = std::make_unique<PSK_Internal>(Server_PSK(selected_id));
152 const auto identities_offset = reader.
read_so_far();
154 std::vector<PskIdentity> psk_identities;
157 const auto obfuscated_ticket_age = reader.
get_uint32_t();
158 psk_identities.emplace_back(std::move(identity), obfuscated_ticket_age);
161 if(psk_identities.empty()) {
165 if(reader.
read_so_far() - identities_offset != identities_length) {
166 throw TLS_Exception(Alert::DecodeError,
"Inconsistent PSK identity list");
172 if(binders_length == 0) {
173 throw TLS_Exception(Alert::DecodeError,
"Empty PSK binders list");
176 std::vector<Client_PSK> psks;
177 for(
auto& psk_identity : psk_identities) {
179 throw TLS_Exception(Alert::IllegalParameter,
"Not enough PSK binders");
185 if(reader.
read_so_far() - binders_offset != binders_length) {
186 throw TLS_Exception(Alert::IllegalParameter,
"Too many PSK binders");
189 m_impl = std::make_unique<PSK_Internal>(std::move(psks));
191 throw TLS_Exception(Alert::DecodeError,
"Found a PSK extension in an unexpected handshake message");
195PSK::PSK(std::optional<Session_with_Handle>& session_to_resume, std::vector<ExternalPSK> psks,
Callbacks& callbacks) {
196 std::vector<Client_PSK> cpsk;
198 if(session_to_resume.has_value()) {
202 for(
auto&& psk : psks) {
203 cpsk.emplace_back(std::move(psk));
206 m_impl = std::make_unique<PSK_Internal>(std::move(cpsk));
210 m_impl(std::make_unique<PSK_Internal>(Server_PSK(psk_index, std::move(session_to_resume)))) {}
212PSK::PSK(ExternalPSK psk,
const uint16_t psk_index) :
213 m_impl(std::make_unique<PSK_Internal>(Server_PSK(psk_index, std::move(psk)))) {}
218 if(std::holds_alternative<Server_PSK>(m_impl->psk)) {
223 return std::get<std::vector<Client_PSK>>(m_impl->psk).
empty();
231 const auto id = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
232 auto& ids = std::get<std::vector<Client_PSK>>(m_impl->psk);
239 if(
id >= ids.size()) {
240 throw TLS_Exception(Alert::IllegalParameter,
"PSK identity selected by server is out of bounds");
243 auto& selected_psk = ids.at(
id);
244 auto cipher_state = selected_psk.take_cipher_state();
248 auto psk_id = [&]() -> std::optional<std::string> {
249 if(selected_psk.is_resumption()) {
252 return selected_psk.identity().identity_as_string();
263 if(!cipher_state->is_compatible_with(cipher)) {
264 throw TLS_Exception(Alert::IllegalParameter,
"PSK and ciphersuite selected by server are not compatible");
267 return {std::move(psk_id), std::move(cipher_state)};
278 auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
283 std::vector<PskIdentity> psk_identities;
285 psks.begin(), psks.end(), std::back_inserter(psk_identities), [&](
const auto& psk) { return psk.identity(); });
286 if(
auto selected_session =
288 auto& [session, psk_index] = selected_session.value();
294 if(session.ciphersuite().prf_algo() != cipher.
prf_algo()) {
296 "Application chose a ticket that is not compatible with the negotiated ciphersuite");
299 return std::unique_ptr<PSK>(
new PSK(std::move(session), psk_index));
305 std::vector<std::string> psk_ids;
306 std::transform(psks.begin(), psks.end(), std::back_inserter(psk_ids), [&](
const auto& psk) {
307 return psk.identity().identity_as_string();
309 if(
auto selected_psk =
311 auto& psk = selected_psk.value();
317 if(psk.prf_algo() != cipher.
prf_algo()) {
319 "Application chose a PSK that is not compatible with the negotiated ciphersuite");
322 const auto selected_itr =
323 std::find_if(psk_identities.begin(), psk_identities.end(), [&](
const auto& offered_psk) {
324 return offered_psk.identity_as_string() == psk.identity();
326 if(selected_itr == psk_identities.end()) {
328 "Application provided a PSK with an identity that was not offered by the client");
342 return std::unique_ptr<PSK>(
343 new PSK(std::move(psk),
static_cast<uint16_t
>(std::distance(psk_identities.begin(), selected_itr))));
351 auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
353 const auto r = std::remove_if(psks.begin(), psks.end(), [&](
const auto& psk) {
354 const auto& cipher_state = psk.cipher_state();
355 return !cipher_state.is_compatible_with(cipher);
357 psks.erase(r, psks.end());
364 [](
auto out) -> std::variant<Session, ExternalPSK> {
365 if constexpr(std::is_same_v<
decltype(out), std::monostate>) {
371 std::get<Server_PSK>(m_impl->psk).take_session_to_resume_or_psk());
375 std::vector<uint8_t> result;
378 [&](
const Server_PSK& psk) {
381 const uint16_t
id = psk.selected_identity();
385 [&](
const std::vector<Client_PSK>& psks) {
388 std::vector<uint8_t> identities;
389 std::vector<uint8_t> binders;
390 for(
const auto& psk : psks) {
391 const auto& psk_identity = psk.identity();
394 const uint32_t obfuscated_ticket_age = psk_identity.obfuscated_age();
395 identities.push_back(
get_byte<0>(obfuscated_ticket_age));
396 identities.push_back(
get_byte<1>(obfuscated_ticket_age));
397 identities.push_back(
get_byte<2>(obfuscated_ticket_age));
398 identities.push_back(
get_byte<3>(obfuscated_ticket_age));
415 for(
auto& psk : std::get<std::vector<Client_PSK>>(m_impl->psk)) {
416 auto tth = truncated_transcript_hash.
clone();
417 const auto& cipher_state = psk.cipher_state();
419 psk.set_binder(cipher_state.psk_binder_mac(tth.truncated()));
427 const uint16_t index = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
428 const auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
431 return psks[index].binder() == binder;
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_STATE_CHECK(expr)
#define BOTAN_ASSERT_NONNULL(ptr)
virtual std::optional< TLS::ExternalPSK > choose_preshared_key(std::string_view host, TLS::Connection_Side whoami, const std::vector< std::string > &identities, const std::optional< std::string > &prf=std::nullopt)
virtual std::chrono::system_clock::time_point tls_current_timestamp()
std::string prf_algo() const
std::pair< std::optional< std::string >, std::unique_ptr< Cipher_State > > take_selected_psk_info(const PSK &server_psk, const Ciphersuite &cipher)
std::variant< Session, ExternalPSK > take_session_to_resume_or_psk()
void filter(const Ciphersuite &cipher)
std::vector< uint8_t > serialize(Connection_Side side) const override
bool validate_binder(const PSK &server_psk, const std::vector< uint8_t > &binder) const
void calculate_binders(const Transcript_Hash_State &truncated_transcript_hash)
std::unique_ptr< PSK > select_offered_psk(std::string_view host, const Ciphersuite &cipher, Session_Manager &session_mgr, Credentials_Manager &credentials_mgr, Callbacks &callbacks, const Policy &policy)
bool empty() const override
PSK(TLS_Data_Reader &reader, uint16_t extension_size, Handshake_Type message_type)
virtual std::optional< std::pair< Session, uint16_t > > choose_from_offered_tickets(const std::vector< PskIdentity > &tickets, std::string_view hash_function, Callbacks &callbacks, const Policy &policy)
bool has_remaining() const
size_t read_so_far() const
std::vector< uint8_t > get_tls_length_value(size_t len_bytes)
Transcript_Hash_State clone() const
void set_algorithm(std::string_view algo_spec)
Strong< std::string, struct PresharedKeyID_ > PresharedKeyID
holds a PSK identity as used in TLS 1.3
void append_tls_length_value(std::vector< uint8_t, Alloc > &buf, const T *vals, size_t vals_size, size_t tag_size)
constexpr uint8_t get_byte(T input)
std::vector< T, secure_allocator< T > > secure_vector