13#include <botan/internal/frodo_matrix.h>
15#include <botan/assert.h>
17#include <botan/internal/bit_ops.h>
18#include <botan/internal/buffer_stuffer.h>
19#include <botan/internal/frodo_constants.h>
20#include <botan/internal/loadstor.h>
26#if defined(BOTAN_HAS_FRODOKEM_AES)
27 #include <botan/internal/frodo_aes_generator.h>
30#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
31 #include <botan/internal/frodo_shake_generator.h>
42std::function<void(std::span<uint8_t> out, uint16_t i)> make_row_generator(
const FrodoKEMConstants& constants,
44#if defined(BOTAN_HAS_FRODOKEM_AES)
45 if(constants.mode().is_aes()) {
50#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
51 if(constants.mode().is_shake()) {
67 const auto n = r.
size() / 2;
69 auto elements = make_elements_vector(
dimensions);
74 for(
auto& elem : elements) {
86 const uint16_t sample_u16 =
static_cast<uint16_t
>(
sample);
88 elem = sign.select(~sample_u16 + 1, sample_u16);
104 m_dim1(std::get<0>(dims)), m_dim2(std::get<1>(dims)), m_elements(make_elements_vector(dims)) {}
112 "FrodoMatrix dimension mismatch of E and S");
114 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
116 auto elements = make_elements_vector(e.
dimensions());
117 auto row_generator = make_row_generator(constants, seed_a);
123 std::vector<uint16_t> a_row_data(4 * constants.
n(), 0);
126 const std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
127 sizeof(uint16_t) * a_row_data.size());
129 for(
size_t i = 0; i < constants.
n(); i += 4) {
133 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 0));
134 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 1));
135 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 2));
136 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 3));
141 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
142 std::array<uint16_t, 4> sum = {0};
143 for(
size_t j = 0; j < constants.
n(); ++j) {
148 const uint32_t sp = s.
elements_at(k * constants.
n() + j);
151 sum.at(0) +=
static_cast<uint16_t
>(a_row_data.at(0 * constants.
n() + j) * sp);
152 sum.at(1) +=
static_cast<uint16_t
>(a_row_data.at(1 * constants.
n() + j) * sp);
153 sum.at(2) +=
static_cast<uint16_t
>(a_row_data.at(2 * constants.
n() + j) * sp);
154 sum.at(3) +=
static_cast<uint16_t
>(a_row_data.at(3 * constants.
n() + j) * sp);
156 elements.at((i + 0) * constants.
n_bar() + k) = e.
elements_at((i + 0) * constants.
n_bar() + k) + sum.at(0);
157 elements.at((i + 3) * constants.
n_bar() + k) = e.
elements_at((i + 3) * constants.
n_bar() + k) + sum.at(3);
158 elements.at((i + 2) * constants.
n_bar() + k) = e.
elements_at((i + 2) * constants.
n_bar() + k) + sum.at(2);
159 elements.at((i + 1) * constants.
n_bar() + k) = e.
elements_at((i + 1) * constants.
n_bar() + k) + sum.at(1);
172 "FrodoMatrix dimension mismatch of E and S");
174 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
176 auto elements = e.m_elements;
177 auto row_generator = make_row_generator(constants, seed_a);
183 std::vector<uint16_t> a_row_data(8 * constants.
n(), 0);
185 const std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
186 sizeof(uint16_t) * a_row_data.size());
189 for(
size_t i = 0; i < constants.
n(); i += 8) {
193 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 0));
194 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 1));
195 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 2));
196 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 3));
197 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 4));
198 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 5));
199 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 6));
200 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 7));
205 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
207 std::array<uint32_t , 8> sp{};
208 for(
size_t p = 0; p < 8; ++p) {
211 for(
size_t q = 0; q < constants.
n(); ++q) {
212 sum = elements.at(j * constants.
n() + q);
213 for(
size_t p = 0; p < 8; ++p) {
214 sum +=
static_cast<uint16_t
>(sp[p] * a_row_data.at(p * constants.
n() + q));
216 elements.at(j * constants.
n() + q) = sum;
230 "FrodoMatrix dimension mismatch of B and S");
232 "FrodoMatrix dimension mismatch of B");
234 "FrodoMatrix dimension mismatch of E");
236 auto elements = make_elements_vector(e.
dimensions());
238 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
239 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
241 for(
size_t j = 0; j < constants.
n(); ++j) {
242 elements.at(k * constants.
n_bar() + i) +=
static_cast<uint16_t
>(
243 static_cast<uint32_t
>(s.
elements_at(k * constants.
n() + j)) *
253 const uint64_t mask = (uint64_t(1) << constants.
b()) - 1;
256 auto elements = make_elements_vector(
dimensions);
261 for(
size_t i = 0; i < (constants.
n_bar() * constants.
n_bar()) / 8; ++i) {
263 for(
size_t j = 0; j < constants.
b(); ++j) {
264 temp |=
static_cast<uint64_t
>(in[i * constants.
b() + j]) << (8 * j);
266 for(
size_t j = 0; j < 8; ++j) {
267 elements.at(pos++) =
static_cast<uint16_t
>((temp & mask) << (constants.
d() - constants.
b()));
268 temp >>= constants.
b();
281 auto elements = make_elements_vector(a.
dimensions());
283 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
296 auto elements = make_elements_vector(a.
dimensions());
298 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
308 return CT::is_equal(
reinterpret_cast<const uint8_t*
>(m_elements.data()),
309 reinterpret_cast<const uint8_t*
>(other.m_elements.data()),
310 sizeof(
decltype(m_elements)::value_type) * m_elements.size());
315 auto elements = make_elements_vector(
dimensions);
317 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
318 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
319 auto& current = elements.at(i * constants.
n_bar() + j);
321 for(
size_t k = 0; k < constants.
n(); ++k) {
323 const uint32_t b_ink = b.
elements_at(i * constants.
n() + k);
326 const uint32_t s_ink = s.
elements_at(j * constants.
n() + k);
328 current +=
static_cast<uint16_t
>(b_ink * s_ink);
359 const uint8_t nbits = std::min(
static_cast<uint8_t
>(8 - b), bits);
360 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
361 const auto t =
static_cast<uint8_t
>((w >> (bits - nbits)) & mask);
362 out[i] = out[i] +
static_cast<uint8_t
>(t << (8 - b - nbits));
368 w = m_elements.at(j);
369 bits =
static_cast<uint8_t
>(constants.
d());
385 for(
unsigned int i = 0; i < m_elements.size(); ++i) {
393 const size_t nwords = (constants.
n_bar() * constants.
n_bar()) / 8;
394 const uint16_t maskex =
static_cast<uint16_t
>(1 << constants.
b()) - 1;
395 const uint16_t maskq =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
400 for(
size_t i = 0; i < nwords; i++) {
401 uint64_t templong = 0;
402 for(
size_t j = 0; j < 8; j++) {
404 static_cast<uint16_t
>(((m_elements.at(index) & maskq) + (1 << (constants.
d() - constants.
b() - 1))) >>
405 (constants.
d() - constants.
b()));
406 templong |=
static_cast<uint64_t
>(temp & maskex) << (constants.
b() * j);
409 for(
size_t j = 0; j < constants.
b(); j++) {
410 out[i * constants.
b() + j] = (templong >> (8 * j)) & 0xFF;
420 const uint8_t lsb =
static_cast<uint8_t
>(constants.
d());
421 const size_t inlen = packed_bytes.
size();
426 auto elements = make_elements_vector(
dimensions);
433 while(i < outlen && (j < inlen || ((j == inlen) && (bits > 0)))) {
447 const uint8_t nbits = std::min(
static_cast<uint8_t
>(lsb - b), bits);
448 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
449 const uint8_t t = (w >> (bits - nbits)) & mask;
451 elements.at(i) = elements.at(i) +
static_cast<uint16_t
>(t << (lsb - b - nbits));
454 w &=
static_cast<uint8_t
>(~(mask << bits));
475 auto elements = make_elements_vector(
dimensions);
483 if(constants.
d() <
sizeof(
decltype(m_elements)::value_type) * 8) {
484 const uint16_t mask =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
485 for(
auto& elem : m_elements) {
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_ASSERT(expr, assertion_made)
#define BOTAN_ASSERT_UNREACHABLE()
Helper class to ease in-place marshalling of concatenated fixed-length values.
static constexpr Mask< T > expand_bit(T v, size_t bit)
static constexpr Mask< T > is_lt(T x, T y)
uint16_t cdf_table_at(size_t i) const
size_t cdf_table_len() const
static FrodoMatrix mul_add_sa_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static std::function< FrodoMatrix(const Dimensions &dimensions)> make_sample_generator(const FrodoKEMConstants &constants, Botan::XOF &shake)
static FrodoMatrix mul_add_sb_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &b, const FrodoMatrix &s, const FrodoMatrix &e)
void reduce(const FrodoKEMConstants &constants)
static FrodoMatrix mul_add_as_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static FrodoMatrix sub(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
FrodoPlaintext decode(const FrodoKEMConstants &constants) const
static FrodoMatrix add(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
std::tuple< size_t, size_t > Dimensions
static FrodoMatrix sample(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoSampleR > r)
CT::Mask< uint8_t > constant_time_compare(const FrodoMatrix &other) const
FrodoPackedMatrix pack(const FrodoKEMConstants &constants) const
static FrodoMatrix encode(const FrodoKEMConstants &constants, StrongSpan< const FrodoPlaintext > in)
Dimensions dimensions() const
FrodoMatrix(Dimensions dims)
static FrodoMatrix unpack(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoPackedMatrix > packed_bytes)
uint16_t elements_at(size_t i) const
size_t element_count() const
static FrodoMatrix mul_bs(const FrodoKEMConstants &constants, const FrodoMatrix &b_p, const FrodoMatrix &s)
static FrodoMatrix deserialize(const Dimensions &dimensions, StrongSpan< const FrodoSerializedMatrix > bytes)
size_t packed_size(const FrodoKEMConstants &constants) const
FrodoSerializedMatrix serialize() const
decltype(auto) data() noexcept(noexcept(this->m_span.data()))
decltype(auto) size() const noexcept(noexcept(this->m_span.size()))
decltype(auto) data() noexcept(noexcept(this->get().data()))
constexpr T value_barrier(T x)
constexpr CT::Mask< T > is_equal(const T x[], const T y[], size_t len)
Strong< secure_vector< uint8_t >, struct FrodoSerializedMatrix_ > FrodoSerializedMatrix
auto create_shake_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
constexpr auto store_le(ParamTs &&... params)
auto create_aes_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
BOTAN_FORCE_INLINE constexpr T ceil_tobytes(T bits)
constexpr auto load_le(ParamTs &&... params)
std::vector< T, secure_allocator< T > > secure_vector
Strong< secure_vector< uint8_t >, struct FrodoPlaintext_ > FrodoPlaintext
Strong< secure_vector< uint8_t >, struct FrodoSampleR_ > FrodoSampleR