8#include <botan/tls_session_manager_sql.h>
10#include <botan/database.h>
12#include <botan/pwdhash.h>
14#include <botan/internal/loadstor.h>
// Tail of the constructor's parameter list; the opening line of the
// signature is not part of this excerpt.
20 std::string_view passphrase,
21 const std::shared_ptr<RandomNumberGenerator>& rng,
22 size_t max_sessions) :
// Forward the RNG to the Session_Manager base, move the database handle into
// m_db and remember the configured session limit.
23 Session_Manager(rng), m_db(std::move(db)), m_max_sessions(max_sessions) {
// The database is opened as part of construction; this call either
// (re)creates the tables or opens the existing ones with the passphrase.
24 create_or_migrate_and_open(passphrase);
// Brings the database into a usable state for this manager, choosing the
// action from the result of detect_schema_revision().
//
// passphrase: passed on to whichever path sets up the session key.
// Throws Internal_Error when the schema revision is not recognised.
//
// NOTE(review): the case labels of the switch are elided in this excerpt, so
// which revisions select which branch cannot be confirmed from here.
27void Session_Manager_SQL::create_or_migrate_and_open(std::string_view passphrase) {
28 switch(detect_schema_revision()) {
// Reset path: both tables are dropped and recreated with the BOTAN_3_0
// schema. Nothing is migrated, so previously stored sessions are discarded.
34 m_db->exec(
"DROP TABLE IF EXISTS tls_sessions");
35 m_db->exec(
"DROP TABLE IF EXISTS tls_sessions_metadata");
36 create_with_latest_schema(passphrase, BOTAN_3_0);
// Reuse path: the existing tables are kept and opened with the passphrase.
39 initialize_existing_database(passphrase);
// Fallback path: a revision value this code does not know how to handle.
42 throw Internal_Error(
"TLS session db has unknown database schema");
// Determines which schema revision the connected database currently has by
// inspecting the tls_sessions_metadata table.
//
// Returns the stored database_revision converted to Schema_Revision when it
// can be read. Throws Internal_Error if the revision cannot be read.
//
// NOTE(review): the try openers, the body of the row-count check and both
// catch handlers are elided in this excerpt; their results are not visible.
46Session_Manager_SQL::Schema_Revision Session_Manager_SQL::detect_schema_revision() {
// The metadata table is expected to contain exactly one row.
48 const auto meta_data_rows = m_db->row_count(
"tls_sessions_metadata");
// Any other row count takes a separate (elided) path — presumably reported as
// a corrupted database; confirm against the full source.
49 if(meta_data_rows != 1) {
52 }
// A database error from the row count query is handled here (handler body
// elided) — presumably the table does not exist yet; confirm.
catch(
const SQL_Database::SQL_DB_Error&) {
// Read the revision number stored in the single metadata row.
57 auto stmt = m_db->new_statement(
"SELECT database_revision FROM tls_sessions_metadata");
// The guard for this throw is elided — presumably taken when the statement
// yields no row.
59 throw Internal_Error(
"Failed to read revision of TLS session database");
// Column 0 of the result is the stored integer revision.
61 return Schema_Revision(stmt->get_size_t(0));
62 }
// A database error while reading the revision is handled here (handler body
// elided).
catch(
const SQL_Database::SQL_DB_Error&) {
// Creates the session and metadata tables, derives the session encryption key
// from the passphrase and records the key derivation parameters in the
// metadata table.
//
// passphrase: input to the password hash that produces the session key.
// rev:        schema revision supplied by the caller; its use is elided in this
//             excerpt (presumably bound to the database_revision column as
//             parameter 5 — confirm).
67void Session_Manager_SQL::create_with_latest_schema(std::string_view passphrase, Schema_Revision rev) {
// Session table definition (only part of the column list is visible here).
// session_id is the primary key and the session blob itself must not be NULL.
69 "CREATE TABLE tls_sessions "
71 "session_id TEXT PRIMARY KEY, "
72 "session_ticket BLOB, "
73 "session_start INTEGER, "
76 "session BLOB NOT NULL"
// Metadata table definition. The column order shown here is the order in which
// initialize_existing_database() reads columns 0..3 via "SELECT *".
80 "CREATE TABLE tls_sessions_metadata "
82 "passphrase_salt BLOB, "
83 "passphrase_iterations INTEGER, "
84 "passphrase_check INTEGER, "
85 "password_hash_family TEXT, "
86 "database_revision INTEGER"
// Index on the ticket column, which is used as a lookup key when deleting by
// ticket.
90 m_db->create_table(
"CREATE INDEX tls_tickets ON tls_sessions (session_ticket)");
// Fresh 16-byte random salt taken from the manager's RNG.
92 auto salt =
m_rng->random_vec<std::vector<uint8_t>>(16);
// Name of the password hash family; it is also stored in the metadata row
// (parameter 4 below).
96 const auto pbkdf_name =
"PBKDF2(SHA-512)";
// Tune the password hash for a desired runtime of 100 ms when producing
// derived_key.size() bytes of output.
99 auto desired_runtime = std::chrono::milliseconds(100);
100 auto pbkdf = pbkdf_fam->tune(derived_key.size(), desired_runtime);
// Arguments of the (elided) key derivation call: output buffer, passphrase and
// salt.
103 derived_key.data(), derived_key.size(), passphrase.data(), passphrase.size(), salt.data(), salt.size());
// The tuned iteration count is stored so the same key can be derived again.
105 const size_t iterations = pbkdf->iterations();
// The first two derived bytes form a 16-bit check value that is stored in the
// database; initialize_existing_database() compares against it to reject a
// wrong passphrase.
106 const size_t check_val =
make_uint16(derived_key[0], derived_key[1]);
// The remaining derived bytes (from offset 2 onwards) become the session key.
107 m_session_key =
SymmetricKey(std::span(derived_key).subspan(2));
// Insert the single metadata row. The bindings for parameters 1 and 5 are
// elided in this excerpt.
109 auto stmt = m_db->new_statement(
"INSERT INTO tls_sessions_metadata VALUES (?1, ?2, ?3, ?4, ?5)");
112 stmt->bind(2, iterations);
113 stmt->bind(3, check_val);
114 stmt->bind(4, pbkdf_name);
// Opens an already existing session database: reads the stored key derivation
// parameters, derives the key from the passphrase again and checks it against
// the stored check value before installing it as the session key.
//
// passphrase: must derive the same check value that is stored in the database.
// Throws Internal_Error if the metadata cannot be read and Invalid_Argument if
// the passphrase does not match the stored check value.
120void Session_Manager_SQL::initialize_existing_database(std::string_view passphrase) {
// "SELECT *" relies on the column order of tls_sessions_metadata: salt,
// iterations, check value, hash family name (columns 0..3 read below).
121 auto stmt = m_db->new_statement(
"SELECT * FROM tls_sessions_metadata");
// The guard for this throw is elided — presumably taken when no metadata row
// could be stepped to.
123 throw Internal_Error(
"Failed to initialize TLS session database");
126 std::pair<const uint8_t*, size_t> salt = stmt->get_blob(0);
127 const size_t iterations = stmt->get_size_t(1);
128 const size_t check_val_db = stmt->get_size_t(2);
129 const std::string pbkdf_name = stmt->get_str(3);
// Recreate the password hash with the stored iteration count instead of
// tuning it again.
134 auto pbkdf = pbkdf_fam->from_params(iterations);
// Arguments of the (elided) key derivation call: output buffer, passphrase and
// the stored salt (pointer and length).
137 derived_key.data(), derived_key.size(), passphrase.data(), passphrase.size(), salt.first, salt.second);
// Same construction as at creation time: the first two derived bytes form the
// 16-bit check value.
139 const size_t check_val_created =
make_uint16(derived_key[0], derived_key[1]);
// A mismatch means the passphrase does not derive the stored check value.
141 if(check_val_created != check_val_db) {
142 throw Invalid_Argument(
"Session database password not valid");
// The derived bytes from offset 2 onwards are installed as the session key.
145 m_session_key =
SymmetricKey(std::span(derived_key).subspan(2));
// Fragment of a function body whose signature is elided in this excerpt —
// presumably Session_Manager_SQL::store(); confirm against the full source.
//
// Lock holder that starts out empty; the code that engages it is elided
// (NOTE(review): presumably depends on database_is_threadsafe() — confirm).
149 std::optional<lock_guard_type<recursive_mutex_type>> lk;
// Insert a row with six positional values into tls_sessions, replacing an
// existing row that conflicts with it.
158 auto stmt = m_db->new_statement(
159 "INSERT OR REPLACE INTO tls_sessions"
160 " VALUES (?1, ?2, ?3, ?4, ?5, ?6)");
// Parameter 2 is the ticket's underlying value; the other bindings are
// elided.
168 stmt->bind(2, ticket.get());
// After adding a row, apply the configured session limit.
176 prune_session_cache();
// Fragment of a function body whose signature is elided in this excerpt —
// presumably Session_Manager_SQL::retrieve_one(); confirm against the full
// source.
//
// Lock holder that starts out empty; the code that engages it is elided.
180 std::optional<lock_guard_type<recursive_mutex_type>> lk;
// The lookup is only performed when the handle provides a session ID.
185 if(
auto session_id = handle.
id()) {
// Select the stored session blob by its ID (binding of ?1 is elided).
186 auto stmt = m_db->new_statement(
"SELECT session FROM tls_sessions WHERE session_id = ?1");
// Walk the result rows; column 0 is the session blob as pointer and length.
// What is done with the blob is elided in this excerpt.
190 while(stmt->step()) {
191 std::pair<const uint8_t*, size_t> blob = stmt->get_blob(0);
// Tail of a signature whose first line is elided in this excerpt — presumably
// Session_Manager_SQL::find_some(); confirm against the full source.
// Collects stored sessions for a given host name and port, newest first, and
// returns them together with their handles.
203 const size_t max_sessions_hint) {
// Lock holder that starts out empty; the code that engages it is elided.
204 std::optional<lock_guard_type<recursive_mutex_type>> lk;
// Query by host name and port, ordered by session_start descending. The end of
// the statement is elided; parameter 3 is bound to max_sessions_hint below
// (NOTE(review): presumably a LIMIT clause — confirm).
209 auto stmt = m_db->new_statement(
210 "SELECT session_id, session_ticket, session FROM tls_sessions"
211 " WHERE hostname = ?1 AND hostport = ?2"
212 " ORDER BY session_start DESC"
// Binding of parameter 1 is elided.
216 stmt->bind(2, info.
port());
217 stmt->bind(3, max_sessions_hint);
219 std::vector<Session_with_Handle> found_sessions;
220 while(stmt->step()) {
// Column 1 is the ticket blob. A non-empty blob is turned into a
// Session_Ticket; the enclosing construct of this return and the alternative
// for an empty blob are elided.
222 auto ticket_blob = stmt->get_blob(1);
223 if(ticket_blob.second > 0) {
224 return Session_Ticket(std::span(ticket_blob.first, ticket_blob.second));
// Column 2 is the stored session blob as pointer and length.
230 std::pair<const uint8_t*, size_t> blob = stmt->get_blob(2);
// One entry is appended per row (the constructor arguments are elided).
233 found_sessions.emplace_back(
238 return found_sessions;
// Fragment of a function body whose signature is elided in this excerpt —
// presumably Session_Manager_SQL::remove(); confirm against the full source.
// Deletes the rows matching the given handle and returns the number of rows
// changed by the last statement. Throws Invalid_Argument if the handle
// provides neither an ID nor a ticket.
//
// A handle that provides a session ID deletes by ID (binding and execution of
// the statement are elided).
246 if(
const auto id = handle.
id()) {
247 auto stmt = m_db->new_statement(
"DELETE FROM tls_sessions WHERE session_id = ?1");
250 }
// Otherwise a handle that provides a ticket deletes by ticket.
else if(
const auto ticket = handle.
ticket()) {
251 auto stmt = m_db->new_statement(
"DELETE FROM tls_sessions WHERE session_ticket = ?1");
252 stmt->bind(1, ticket->get());
// Neither an ID nor a ticket: there is no column to match the handle against.
256 throw Invalid_Argument(
"provided a session handle that is neither ID nor ticket");
// Report how many rows the delete statement changed.
259 return m_db->rows_changed_by_last_statement();
// Fragment of a function body whose signature is elided in this excerpt —
// presumably Session_Manager_SQL::remove_all(); confirm against the full
// source.
// Deletes every row of tls_sessions and returns the number of rows the
// statement changed.
267 m_db->exec(
"DELETE FROM tls_sessions");
268 return m_db->rows_changed_by_last_statement();
// Enforces the configured session limit: keeps the m_max_sessions rows with
// the highest session_start and deletes all other rows of tls_sessions.
271void Session_Manager_SQL::prune_session_cache() {
// A limit of zero is handled separately; the body of this branch is elided
// (NOTE(review): presumably returns without pruning — confirm).
274 if(m_max_sessions == 0) {
// The inner SELECT yields the IDs of the newest m_max_sessions sessions; every
// row whose ID is not in that set is deleted.
278 auto remove_oldest = m_db->new_statement(
279 "DELETE FROM tls_sessions WHERE session_id NOT IN "
280 "(SELECT session_id FROM tls_sessions ORDER BY session_start DESC LIMIT ?1)");
281 remove_oldest->bind(1, m_max_sessions);
// NOTE(review): spin() presumably runs the statement to completion — confirm.
282 remove_oldest->spin();
static std::unique_ptr< PasswordHashFamily > create_or_throw(std::string_view algo_spec, std::string_view provider="")
std::chrono::system_clock::time_point start_time() const
const Server_Information & server_info() const
Helper class to embody a session handle in all protocol versions.
std::optional< Session_Ticket > ticket() const
std::optional< Session_ID > id() const
Session_Manager_SQL(std::shared_ptr< SQL_Database > db, std::string_view passphrase, const std::shared_ptr< RandomNumberGenerator > &rng, size_t max_sessions=1000)
void store(const Session &session, const Session_Handle &handle) override
Save a Session under a Session_Handle (TLS Client)
size_t remove(const Session_Handle &handle) override
std::vector< Session_with_Handle > find_some(const Server_Information &info, size_t max_sessions_hint) override
Internal retrieval function to find sessions to resume.
virtual bool database_is_threadsafe() const
std::optional< Session > retrieve_one(const Session_Handle &handle) override
Internal retrieval function for a single session.
size_t remove_all() override
recursive_mutex_type & mutex()
std::shared_ptr< RandomNumberGenerator > m_rng
std::vector< uint8_t > encrypt(const SymmetricKey &key, RandomNumberGenerator &rng) const
static Session decrypt(const uint8_t ctext[], size_t ctext_size, const SymmetricKey &key)
Strong< std::vector< uint8_t >, struct Session_ID_ > Session_ID
holds a TLS 1.2 session ID for stateful resumption
Strong< std::vector< uint8_t >, struct Session_Ticket_ > Session_Ticket
holds a TLS 1.2 session ticket for stateless resumption
void hex_encode(char output[], const uint8_t input[], size_t input_length, bool uppercase)
size_t hex_decode(uint8_t output[], const char input[], size_t input_length, size_t &input_consumed, bool ignore_ws)
std::vector< T, secure_allocator< T > > secure_vector
constexpr uint16_t make_uint16(uint8_t i0, uint8_t i1)