10#include <botan/certstor_windows.h>
12#include <botan/assert.h>
13#include <botan/ber_dec.h>
14#include <botan/hash.h>
15#include <botan/mutex.h>
16#include <botan/pkix_types.h>
17#include <botan/internal/fmt.h>
18#include <botan/internal/x509_cert_cache.h>
22#include <unordered_map>
23#include <unordered_set>
// Names of the Windows system certificate stores this backend reads from.
// Each name is passed to CertOpenSystemStoreA by the Certificate_Store_Windows_Impl
// constructor. The handles are appended to m_stores in this order, and Cert_Enumerator
// walks m_stores from index 0, so "Root" is searched before "CA".
35constexpr std::array<const char*, 2> cert_store_names{
"Root",
"CA"};
46class Cert_Context final {
48 Cert_Context() : m_ctx(
nullptr) {}
51 if(m_ctx !=
nullptr) {
52 CertFreeCertificateContext(m_ctx);
56 Cert_Context(
const Cert_Context&) =
delete;
57 Cert_Context(Cert_Context&&) =
delete;
58 Cert_Context& operator=(
const Cert_Context&) =
delete;
59 Cert_Context& operator=(Cert_Context&&) =
delete;
61 bool assign(PCCERT_CONTEXT ctx) {
63 return m_ctx !=
nullptr;
// Returns the wrapped certificate context without giving up ownership; the wrapper's
// destructor still frees it. Returns nullptr when no context is held, e.g. after
// default construction.
66 PCCERT_CONTEXT get()
const {
return m_ctx; }
// Member access to the wrapped CERT_CONTEXT (e.g. ctx->pbCertEncoded).
// No null check is performed; the caller must ensure a non-null context is held.
68 PCCERT_CONTEXT operator->()
const {
return m_ctx; }
82class Cert_Enumerator final {
84 using Next_Fn = std::function<PCCERT_CONTEXT(HCERTSTORE, PCCERT_CONTEXT)>;
86 Cert_Enumerator(std::span<const HCERTSTORE> stores, Next_Fn fn) : m_stores(stores), m_get_next(std::move(fn)) {}
88 Cert_Enumerator(
const Cert_Enumerator&) =
delete;
89 Cert_Enumerator(Cert_Enumerator&&) =
delete;
90 Cert_Enumerator& operator=(
const Cert_Enumerator&) =
delete;
91 Cert_Enumerator& operator=(Cert_Enumerator&&) =
delete;
92 ~Cert_Enumerator() =
default;
94 PCCERT_CONTEXT next() {
95 while(m_store_idx < m_stores.size()) {
96 if(m_ctx.assign(m_get_next(m_stores[m_store_idx], m_ctx.get()))) {
108 std::span<const HCERTSTORE> m_stores;
110 size_t m_store_idx = 0;
118class Pubkey_SHA1 final {
// Size in bytes of a SHA-1 digest. Both the constructor and
// find_cert_by_pubkey_sha1 reject inputs of any other length.
120 static constexpr size_t LEN = 20;
122 explicit Pubkey_SHA1(std::span<const uint8_t> bytes) {
123 BOTAN_ARG_CHECK(bytes.size() == LEN,
"invalid SHA-1 pubkey hash length");
124 std::copy(bytes.begin(), bytes.end(), m_hash.begin());
127 static Pubkey_SHA1 compute(HashFunction& sha1, std::span<const uint8_t> data) {
130 sha1.final(result.m_hash);
134 auto operator<=>(
const Pubkey_SHA1&)
const =
default;
136 size_t hash()
const noexcept {
138 std::memcpy(&h, m_hash.data(),
sizeof(h));
143 Pubkey_SHA1() =
default;
144 std::array<uint8_t, LEN> m_hash = {};
147struct Pubkey_SHA1_Hasher {
148 size_t operator()(
const Pubkey_SHA1& h)
const noexcept {
return h.hash(); }
156class Certificate_Store_Windows_Impl final {
// Size argument passed to m_cert_cache (X509_Certificate_Cache) in the constructor.
// NOTE(review): presumably the maximum number of cached decoded certificates;
// confirm against x509_cert_cache.h.
160 static constexpr size_t SystemStore_CertCacheSize = 128;
162 Certificate_Store_Windows_Impl() : m_cert_cache(SystemStore_CertCacheSize) {
163 for(
const auto* cert_store_name : cert_store_names) {
164 auto* store = CertOpenSystemStoreA(0, cert_store_name);
165 if(store ==
nullptr) {
166 const auto err = ::GetLastError();
168 throw System_Error(
fmt(
"Failed to open Windows certificate store '{}'", cert_store_name), err);
171 CertControlStore(store, 0, CERT_STORE_CTRL_AUTO_RESYNC,
nullptr);
173 m_stores.push_back(store);
// Releases every store handle opened by the constructor. close_stores() calls
// CertCloseStore(store, 0) on each entry of m_stores.
177 ~Certificate_Store_Windows_Impl() { close_stores(); }
179 Certificate_Store_Windows_Impl(
const Certificate_Store_Windows_Impl&) =
delete;
180 Certificate_Store_Windows_Impl(Certificate_Store_Windows_Impl&&) =
delete;
181 Certificate_Store_Windows_Impl& operator=(
const Certificate_Store_Windows_Impl&) =
delete;
182 Certificate_Store_Windows_Impl& operator=(Certificate_Store_Windows_Impl&&) =
delete;
184 std::optional<X509_Certificate> find_cert(
const X509_DN& subject_dn,
const std::vector<uint8_t>& key_id) {
187 const auto certs = find_cert_by_dn_and_key_id(subject_dn, key_id,
true);
191 return certs.front();
194 std::vector<X509_Certificate> find_all_certs(
const X509_DN& subject_dn,
const std::vector<uint8_t>& key_id) {
197 return find_cert_by_dn_and_key_id(subject_dn, key_id,
false);
200 std::optional<X509_Certificate> find_cert_by_pubkey_sha1(
const std::vector<uint8_t>& key_hash) {
201 if(key_hash.size() != Pubkey_SHA1::LEN) {
202 throw Invalid_Argument(
"Certificate_Store_Windows::find_cert_by_pubkey_sha1 invalid hash");
205 const Pubkey_SHA1 target(key_hash);
209 if(
const auto hit = m_sha1_pubkey_to_cert.find(target); hit != m_sha1_pubkey_to_cert.end()) {
214 while(m_sha1_pubkey_to_cert.size() >= 1024) {
215 m_sha1_pubkey_to_cert.erase(m_sha1_pubkey_to_cert.begin());
220 Cert_Enumerator enumerator(
221 m_stores, [](HCERTSTORE store, PCCERT_CONTEXT prev) {
return CertEnumCertificatesInStore(store, prev); });
223 while(
const auto* ctx = enumerator.next()) {
224 const auto pubkey_blob = ctx->pCertInfo->SubjectPublicKeyInfo.PublicKey;
225 const auto candidate =
226 Pubkey_SHA1::compute(*sha1, {
static_cast<uint8_t*
>(pubkey_blob.pbData), pubkey_blob.cbData});
228 if(candidate == target) {
229 auto result = materialize(ctx->pbCertEncoded, ctx->cbCertEncoded);
230 m_sha1_pubkey_to_cert.emplace(target, result);
236 m_sha1_pubkey_to_cert.emplace(target, std::nullopt);
240 std::optional<X509_Certificate> find_cert_by_issuer_dn_and_serial_number(
const X509_DN& issuer_dn,
241 std::span<const uint8_t> serial_number) {
244 const std::vector<uint8_t> dn_data = issuer_dn.BER_encode();
246 const _CRYPTOAPI_BLOB blob{
247 .cbData =
static_cast<DWORD
>(dn_data.size()),
248 .pbData =
const_cast<BYTE*
>(dn_data.data()),
251 auto filter = [&](
const X509_Certificate& cert) {
252 return std::ranges::equal(cert.serial_number(), serial_number);
255 const auto certs = search_cert_stores(blob, CERT_FIND_ISSUER_NAME, filter,
true);
259 return certs.front();
262 bool contains(
const X509_Certificate& cert) {
263 const auto cert_sha1 = cert.certificate_data_sha1();
264 const auto cert_sha256 = cert.certificate_data_sha256();
266 const CRYPT_HASH_BLOB sha1_blob{
267 .cbData =
static_cast<DWORD
>(cert_sha1.size()),
268 .pbData =
const_cast<BYTE*
>(cert_sha1.data()),
273 Cert_Enumerator enumerator(m_stores, [&sha1_blob](HCERTSTORE store, PCCERT_CONTEXT prev) {
274 return CertFindCertificateInStore(
275 store, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_SHA1_HASH, &sha1_blob, prev);
278 while(
const auto* ctx = enumerator.next()) {
279 const auto found = materialize(ctx->pbCertEncoded, ctx->cbCertEncoded);
280 if(std::ranges::equal(found.certificate_data_sha256(), cert_sha256)) {
288 std::vector<X509_DN> all_subjects() {
291 std::vector<X509_DN> subject_dns;
293 Cert_Enumerator enumerator(
294 m_stores, [](HCERTSTORE store, PCCERT_CONTEXT prev) {
return CertEnumCertificatesInStore(store, prev); });
296 while(
const auto* ctx = enumerator.next()) {
297 BER_Decoder dec(ctx->pCertInfo->Subject.pbData, ctx->pCertInfo->Subject.cbData);
300 subject_dns.emplace_back(std::move(dn));
307 void close_stores() {
308 for(
auto* store : m_stores) {
309 CertCloseStore(store, 0);
// Builds an X509_Certificate from the DER bytes [der, der + len) by delegating to
// m_cert_cache.find_or_insert. All call sites pass a CERT_CONTEXT's pbCertEncoded
// and cbCertEncoded.
// NOTE(review): presumably returns a cached instance when the same encoding was
// decoded before; confirm in X509_Certificate_Cache.
314 X509_Certificate materialize(
const BYTE* der, DWORD len) {
return m_cert_cache.
find_or_insert({der, len}); }
317 std::vector<X509_Certificate> find_cert_by_dn_and_key_id(
const X509_DN& subject_dn,
318 const std::vector<uint8_t>& key_id,
319 bool return_on_first_found) {
320 _CRYPTOAPI_BLOB blob{};
322 std::vector<uint8_t> dn_data;
326 find_type = CERT_FIND_SUBJECT_NAME;
327 dn_data = subject_dn.DER_encode();
328 blob.cbData =
static_cast<DWORD
>(dn_data.size());
329 blob.pbData =
reinterpret_cast<BYTE*
>(dn_data.data());
331 find_type = CERT_FIND_KEY_IDENTIFIER;
332 blob.cbData =
static_cast<DWORD
>(key_id.size());
333 blob.pbData =
const_cast<BYTE*
>(key_id.data());
336 auto filter = [&](
const X509_Certificate& cert) {
return key_id.empty() || cert.subject_dn() == subject_dn; };
338 return search_cert_stores(blob, find_type, filter, return_on_first_found);
342 std::vector<X509_Certificate> search_cert_stores(
const _CRYPTOAPI_BLOB& blob,
344 const std::function<
bool(
const X509_Certificate&)>& filter,
345 bool return_on_first_found) {
346 std::vector<X509_Certificate> certs;
347 std::unordered_set<X509_Certificate::Tag, X509_Certificate::TagHash> seen;
349 Cert_Enumerator enumerator(m_stores, [&blob, find_type](HCERTSTORE store, PCCERT_CONTEXT prev) {
350 return CertFindCertificateInStore(
351 store, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, find_type, &blob, prev);
354 while(
const auto* ctx = enumerator.next()) {
355 auto cert = materialize(ctx->pbCertEncoded, ctx->cbCertEncoded);
356 if(!seen.insert(cert.tag()).second) {
360 certs.push_back(std::move(cert));
361 if(return_on_first_found) {
371 std::vector<HCERTSTORE> m_stores;
372 std::unordered_map<Pubkey_SHA1, std::optional<X509_Certificate>, Pubkey_SHA1_Hasher> m_sha1_pubkey_to_cert;
373 X509_Certificate_Cache m_cert_cache;
379 return m_impl->all_subjects();
383 const std::vector<uint8_t>& key_id)
const {
384 return m_impl->find_cert(subject_dn, key_id);
388 const std::vector<uint8_t>& key_id)
const {
389 return m_impl->find_all_certs(subject_dn, key_id);
393 const std::vector<uint8_t>& key_hash)
const {
394 return m_impl->find_cert_by_pubkey_sha1(key_hash);
398 const std::vector<uint8_t>& subject_hash)
const {
400 throw Not_Implemented(
"Certificate_Store_Windows::find_cert_by_raw_subject_dn_sha256");
404 const X509_DN& issuer_dn, std::span<const uint8_t> serial_number)
const {
405 return m_impl->find_cert_by_issuer_dn_and_serial_number(issuer_dn, serial_number);
415 return m_impl->contains(cert);
#define BOTAN_ARG_CHECK(expr, msg)
Certificate_Store_Windows()
bool contains(const X509_Certificate &cert) const override
std::optional< X509_Certificate > find_cert_by_issuer_dn_and_serial_number(const X509_DN &issuer_dn, std::span< const uint8_t > serial_number) const override
std::vector< X509_DN > all_subjects() const override
std::optional< X509_Certificate > find_cert_by_pubkey_sha1(const std::vector< uint8_t > &key_hash) const override
std::optional< X509_Certificate > find_cert(const X509_DN &subject_dn, const std::vector< uint8_t > &key_id) const override
std::optional< X509_Certificate > find_cert_by_raw_subject_dn_sha256(const std::vector< uint8_t > &subject_hash) const override
std::optional< X509_CRL > find_crl_for(const X509_Certificate &subject) const override
std::vector< X509_Certificate > find_all_certs(const X509_DN &subject_dn, const std::vector< uint8_t > &key_id) const override
static std::unique_ptr< HashFunction > create_or_throw(std::string_view algo_spec, std::string_view provider="")
X509_Certificate find_or_insert(std::span< const uint8_t > encoding)
std::string fmt(std::string_view format, const T &... args)
secure_vector< T > lock(const std::vector< T > &in)
auto operator<=>(const Strong< T, Tags... > &lhs, const Strong< T, Tags... > &rhs)
lock_guard< T > lock_guard_type