19#include <botan/xmss.h>
21#include <botan/ber_dec.h>
22#include <botan/der_enc.h>
24#include <botan/internal/buffer_slicer.h>
25#include <botan/internal/concat_util.h>
26#include <botan/internal/int_utils.h>
27#include <botan/internal/loadstor.h>
28#include <botan/internal/stateful_key_index_registry.h>
29#include <botan/internal/xmss_common_ops.h>
30#include <botan/internal/xmss_hash.h>
31#include <botan/internal/xmss_signature_operation.h>
33#if defined(BOTAN_HAS_THREAD_UTILS)
34 #include <botan/internal/thread_pool.h>
47 if(key_bits.size() == xmss_params.raw_private_key_size() ||
48 key_bits.size() == xmss_params.raw_legacy_private_key_size()) {
49 raw_key.assign(key_bits.begin(), key_bits.end());
59class XMSS_PrivateKey_Internal {
61 XMSS_PrivateKey_Internal(
const XMSS_Parameters& xmss_params,
62 const XMSS_WOTS_Parameters& wots_params,
64 RandomNumberGenerator& rng) :
65 m_xmss_params(xmss_params),
66 m_wots_params(wots_params),
67 m_wots_derivation_method(wots_derivation_method),
68 m_prf(rng.random_vec(xmss_params.element_size())),
69 m_private_seed(rng.random_vec(xmss_params.element_size())),
70 m_keyid(Stateful_Key_Index_Registry::KeyId(
"XMSS", m_xmss_params.oid(), m_private_seed, m_prf)) {}
72 XMSS_PrivateKey_Internal(
const XMSS_Parameters& xmss_params,
73 const XMSS_WOTS_Parameters& wots_params,
77 m_xmss_params(xmss_params),
78 m_wots_params(wots_params),
79 m_wots_derivation_method(wots_derivation_method),
80 m_prf(std::move(prf)),
81 m_private_seed(std::move(private_seed)),
82 m_keyid(Stateful_Key_Index_Registry::KeyId(
"XMSS", m_xmss_params.oid(), m_private_seed, m_prf)) {}
84 XMSS_PrivateKey_Internal(
const XMSS_Parameters& xmss_params,
85 const XMSS_WOTS_Parameters& wots_params,
86 std::span<const uint8_t> key_bits) :
87 m_xmss_params(xmss_params), m_wots_params(wots_params), m_keyid() {
96 static_assert(
sizeof(size_t) >= 4,
"size_t is big enough to support leaf index");
100 if(raw_key.size() != m_xmss_params.raw_private_key_size() &&
101 raw_key.size() != m_xmss_params.raw_legacy_private_key_size()) {
102 throw Decoding_Error(
"Invalid XMSS private key size");
105 BufferSlicer s(raw_key);
108 s.skip(m_xmss_params.raw_public_key_size());
110 auto unused_leaf_bytes = s.take(
sizeof(uint32_t));
112 if(unused_leaf >= (1ULL << m_xmss_params.tree_height())) {
113 throw Decoding_Error(
"XMSS private key leaf index out of bounds");
116 m_prf = s.copy_as_secure_vector(m_xmss_params.element_size());
117 m_private_seed = s.copy_as_secure_vector(m_xmss_params.element_size());
119 m_keyid = Stateful_Key_Index_Registry::KeyId(
"XMSS", m_xmss_params.oid(), m_private_seed, m_prf);
122 set_unused_leaf_index(unused_leaf);
126 m_wots_derivation_method =
133 std::vector<uint8_t> unused_index(4);
134 store_be(
static_cast<uint32_t
>(unused_leaf_index()), unused_index.data());
136 std::vector<uint8_t> wots_derivation_method;
137 wots_derivation_method.push_back(
static_cast<uint8_t
>(m_wots_derivation_method));
140 raw_public_key, unused_index, m_prf, m_private_seed, wots_derivation_method);
147 const XMSS_WOTS_Parameters& wots_parameters() {
return m_wots_params; }
151 void set_unused_leaf_index(
size_t idx) {
152 if(idx >= (1ULL << m_xmss_params.tree_height())) {
153 throw Decoding_Error(
"XMSS private key leaf index out of bounds");
159 size_t reserve_unused_leaf_index() {
161 if(idx >= m_xmss_params.total_number_of_signatures()) {
162 throw Decoding_Error(
"XMSS private key, one time signatures exhausted");
165 return static_cast<size_t>(idx);
168 size_t unused_leaf_index()
const {
173 uint64_t remaining_signatures()
const {
174 const size_t max = m_xmss_params.total_number_of_signatures();
179 XMSS_Parameters m_xmss_params;
180 XMSS_WOTS_Parameters m_wots_params;
185 Stateful_Key_Index_Registry::KeyId m_keyid;
210 m_private(std::make_shared<XMSS_PrivateKey_Internal>(
212 m_private->set_unused_leaf_index(idx_leaf);
214 "XMSS: unexpected byte length of PRF value");
216 "XMSS: unexpected byte length of private seed");
220 size_t target_node_height,
224 BOTAN_ASSERT((start_idx % (
static_cast<size_t>(1) << target_node_height)) == 0,
225 "Start index must be divisible by 2^{target node height}.");
227#if defined(BOTAN_HAS_THREAD_UTILS)
232 const size_t split_level = std::min(target_node_height, thread_pool.
worker_count());
235 if(split_level == 0) {
238 tree_hash_subtree(result, start_idx, target_node_height, subtree_addr, hash);
242 const size_t subtrees =
static_cast<size_t>(1) << split_level;
243 const size_t last_idx = (
static_cast<size_t>(1) << (target_node_height)) + start_idx;
244 const size_t offs = (last_idx - start_idx) / subtrees;
246 uint8_t level =
static_cast<uint8_t
>(split_level);
249 "Number of worker threads in tree_hash need to divide range "
250 "of calculated nodes.");
252 std::vector<secure_vector<uint8_t>> nodes(subtrees,
254 std::vector<XMSS_Address> node_addresses(subtrees, adrs);
255 std::vector<XMSS_Hash> xmss_hash(subtrees, hash);
256 std::vector<std::future<void>> work;
259 for(
size_t i = 0; i < subtrees; i++) {
260 using tree_hash_subtree_fn_t =
263 const tree_hash_subtree_fn_t work_fn = &XMSS_PrivateKey::tree_hash_subtree;
265 work.push_back(thread_pool.
run(work_fn,
268 start_idx + i * offs,
269 target_node_height - split_level,
270 std::ref(node_addresses[i]),
271 std::ref(xmss_hash[i])));
274 for(
auto& w : work) {
281 std::vector<secure_vector<uint8_t>> ro_nodes(nodes.begin(),
282 nodes.begin() + (
static_cast<size_t>(1) << (level + 1)));
284 for(
size_t i = 0; i < (static_cast<size_t>(1) << level); i++) {
287 node_addresses[i].set_tree_height(
static_cast<uint32_t
>(target_node_height - (level + 1)));
288 node_addresses[i].set_tree_index((node_addresses[2 * i + 1].get_tree_index() - 1) >> 1);
292 std::cref(ro_nodes[2 * i]),
293 std::cref(ro_nodes[2 * i + 1]),
295 std::cref(this->public_seed()),
296 std::ref(xmss_hash[i]),
300 for(
auto& w : work) {
307 node_addresses[0].set_tree_height(
static_cast<uint32_t
>(target_node_height - 1));
308 node_addresses[0].set_tree_index((node_addresses[1].get_tree_index() - 1) >> 1);
314 XMSS_Address subtree_addr(adrs);
315 tree_hash_subtree(result, start_idx, target_node_height, subtree_addr, hash);
322 size_t target_node_height,
327 std::vector<secure_vector<uint8_t>> nodes(target_node_height + 1,
334 std::vector<uint8_t> node_levels(target_node_height + 1);
337 const size_t last_idx = (
static_cast<size_t>(1) << target_node_height) + start_idx;
339 for(
size_t i = start_idx; i < last_idx; i++) {
341 adrs.set_ots_address(
static_cast<uint32_t
>(i));
343 const XMSS_WOTS_PublicKey pk = this->wots_public_key_for(adrs, hash);
346 adrs.set_ltree_address(
static_cast<uint32_t
>(i));
348 node_levels[level] = 0;
351 adrs.set_tree_height(0);
352 adrs.set_tree_index(
static_cast<uint32_t
>(i));
354 while(level > 0 && node_levels[level] == node_levels[level - 1]) {
355 adrs.set_tree_index(((adrs.get_tree_index() - 1) >> 1));
357 nodes[level - 1], nodes[level - 1], nodes[level], adrs, seed, hash,
m_xmss_params);
358 node_levels[level - 1]++;
360 adrs.set_tree_height(adrs.get_tree_height() + 1);
364 result = nodes[level - 1];
368 const auto private_key = wots_private_key_for(adrs, hash);
369 return XMSS_WOTS_PublicKey(m_private->wots_parameters(),
m_public_seed, private_key, adrs, hash);
375 return XMSS_WOTS_PrivateKey(
376 m_private->wots_parameters(),
m_public_seed, m_private->private_seed(), adrs, hash);
378 return XMSS_WOTS_PrivateKey(m_private->wots_parameters(), m_private->private_seed(), adrs, hash);
381 throw Invalid_State(
"WOTS derivation method is out of the enum's range");
388size_t XMSS_PrivateKey::reserve_unused_leaf_index() {
389 return m_private->reserve_unused_leaf_index();
393 return m_private->unused_leaf_index();
401 return m_private->remaining_signatures();
405 return m_private->prf_value();
413 return m_private->wots_derivation_method();
422 std::string_view provider)
const {
423 if(provider ==
"base" || provider.empty()) {
424 return std::make_unique<XMSS_Signature_Operation>(*
this);
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_ARG_CHECK(expr, msg)
#define BOTAN_ASSERT(expr, assertion_made)
BER_Decoder & decode(bool &out)
BER_Decoder & verify_end()
secure_vector< uint8_t > get_contents()
DER_Encoder & encode(bool b)
uint64_t current_index(const KeyId &key_id)
void set_index_lower_bound(const KeyId &key_id, uint64_t min)
static Stateful_Key_Index_Registry & global()
uint64_t reserve_next_index(const KeyId &key_id)
uint64_t remaining_operations(const KeyId &key_id, uint64_t max)
size_t worker_count() const
auto run(F &&f, Args &&... args) -> std::future< std::invoke_result_t< F, Args... > >
static Thread_Pool & global_instance()
static void create_l_tree(secure_vector< uint8_t > &result, wots_keysig_t pk, XMSS_Address adrs, const secure_vector< uint8_t > &seed, XMSS_Hash &hash, const XMSS_Parameters ¶ms)
static void randomize_tree_hash(secure_vector< uint8_t > &result, const secure_vector< uint8_t > &left, const secure_vector< uint8_t > &right, XMSS_Address adrs, const secure_vector< uint8_t > &seed, XMSS_Hash &hash, const XMSS_Parameters ¶ms)
std::unique_ptr< Public_Key > public_key() const override
size_t remaining_signatures() const
size_t unused_leaf_index() const
std::optional< uint64_t > remaining_operations() const override
Retrieves the number of remaining operations if this is a stateful private key.
WOTS_Derivation_Method wots_derivation_method() const
secure_vector< uint8_t > raw_private_key() const
secure_vector< uint8_t > private_key_bits() const override
std::unique_ptr< PK_Ops::Signature > create_signature_op(RandomNumberGenerator &rng, std::string_view params, std::string_view provider) const override
XMSS_PrivateKey(XMSS_Parameters::xmss_algorithm_t xmss_algo_id, RandomNumberGenerator &rng, WOTS_Derivation_Method wots_derivation_method=WOTS_Derivation_Method::NIST_SP800_208)
const secure_vector< uint8_t > & root() const
secure_vector< uint8_t > m_root
const secure_vector< uint8_t > & public_seed() const
secure_vector< uint8_t > m_public_seed
const XMSS_Parameters & xmss_parameters() const
XMSS_Parameters m_xmss_params
std::vector< uint8_t > raw_public_key() const
XMSS_WOTS_Parameters m_wots_params
std::string algo_name() const override
XMSS_PublicKey(XMSS_Parameters::xmss_algorithm_t xmss_oid, RandomNumberGenerator &rng)
constexpr RT checked_cast_to(AT i)
constexpr auto concat(Rs &&... ranges)
std::vector< T, secure_allocator< T > > secure_vector
constexpr auto store_be(ParamTs &&... params)
constexpr auto load_be(ParamTs &&... params)