13#include <botan/internal/frodo_matrix.h>
15#include <botan/assert.h>
16#include <botan/frodokem.h>
18#include <botan/mem_ops.h>
20#include <botan/internal/bit_ops.h>
21#include <botan/internal/frodo_constants.h>
22#include <botan/internal/loadstor.h>
23#include <botan/internal/stl_util.h>
25#if defined(BOTAN_HAS_FRODOKEM_AES)
26 #include <botan/internal/frodo_aes_generator.h>
29#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
30 #include <botan/internal/frodo_shake_generator.h>
49std::function<void(std::span<uint8_t> out, uint16_t i)> make_row_generator(
const FrodoKEMConstants& constants,
50 StrongSpan<const FrodoSeedA> seed_a) {
51#if defined(BOTAN_HAS_FRODOKEM_AES)
52 if(constants.mode().is_aes()) {
57#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
58 if(constants.mode().is_shake()) {
74 const auto n = r.
size() / 2;
76 auto elements = make_elements_vector(
dimensions);
81 for(
auto& elem : elements) {
93 const uint16_t sample_u16 =
static_cast<uint16_t
>(
sample);
95 elem = sign.select(~sample_u16 + 1, sample_u16);
111 m_dim1(std::get<0>(dims)), m_dim2(std::get<1>(dims)), m_elements(make_elements_vector(dims)) {}
119 "FrodoMatrix dimension mismatch of E and S");
121 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
123 auto elements = make_elements_vector(e.
dimensions());
124 auto row_generator = make_row_generator(constants, seed_a);
130 std::vector<uint16_t> a_row_data(4 * constants.
n(), 0);
133 std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
134 sizeof(uint16_t) * a_row_data.size());
136 for(
size_t i = 0; i < constants.
n(); i += 4) {
140 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 0));
141 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 1));
142 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 2));
143 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 3));
148 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
149 std::array<uint16_t, 4> sum = {0};
150 for(
size_t j = 0; j < constants.
n(); ++j) {
155 const uint32_t sp = s.
elements_at(k * constants.
n() + j);
158 sum.at(0) +=
static_cast<uint16_t
>(a_row_data.at(0 * constants.
n() + j) * sp);
159 sum.at(1) +=
static_cast<uint16_t
>(a_row_data.at(1 * constants.
n() + j) * sp);
160 sum.at(2) +=
static_cast<uint16_t
>(a_row_data.at(2 * constants.
n() + j) * sp);
161 sum.at(3) +=
static_cast<uint16_t
>(a_row_data.at(3 * constants.
n() + j) * sp);
163 elements.at((i + 0) * constants.
n_bar() + k) = e.
elements_at((i + 0) * constants.
n_bar() + k) + sum.at(0);
164 elements.at((i + 3) * constants.
n_bar() + k) = e.
elements_at((i + 3) * constants.
n_bar() + k) + sum.at(3);
165 elements.at((i + 2) * constants.
n_bar() + k) = e.
elements_at((i + 2) * constants.
n_bar() + k) + sum.at(2);
166 elements.at((i + 1) * constants.
n_bar() + k) = e.
elements_at((i + 1) * constants.
n_bar() + k) + sum.at(1);
179 "FrodoMatrix dimension mismatch of E and S");
181 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
183 auto elements = e.m_elements;
184 auto row_generator = make_row_generator(constants, seed_a);
190 std::vector<uint16_t> a_row_data(8 * constants.
n(), 0);
192 std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
193 sizeof(uint16_t) * a_row_data.size());
196 for(
size_t i = 0; i < constants.
n(); i += 8) {
200 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 0));
201 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 1));
202 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 2));
203 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 3));
204 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 4));
205 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 5));
206 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 6));
207 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 7));
212 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
214 std::array<uint32_t , 8> sp;
215 for(
size_t p = 0; p < 8; ++p) {
218 for(
size_t q = 0; q < constants.
n(); ++q) {
219 sum = elements.at(j * constants.
n() + q);
220 for(
size_t p = 0; p < 8; ++p) {
221 sum +=
static_cast<uint16_t
>(sp[p] * a_row_data.at(p * constants.
n() + q));
223 elements.at(j * constants.
n() + q) = sum;
236 std::get<1>(
b.dimensions()) == std::get<0>(s.
dimensions()),
237 "FrodoMatrix dimension mismatch of B and S");
238 BOTAN_ASSERT(std::get<0>(
b.dimensions()) == constants.
n() && std::get<1>(
b.dimensions()) == constants.
n_bar(),
239 "FrodoMatrix dimension mismatch of B");
241 "FrodoMatrix dimension mismatch of E");
243 auto elements = make_elements_vector(e.
dimensions());
245 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
246 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
248 for(
size_t j = 0; j < constants.
n(); ++j) {
249 elements.at(k * constants.
n_bar() + i) +=
static_cast<uint16_t
>(
250 static_cast<uint32_t
>(s.
elements_at(k * constants.
n() + j)) *
251 b.elements_at(j * constants.
n_bar() + i));
260 const uint64_t mask = (uint64_t(1) << constants.
b()) - 1;
263 auto elements = make_elements_vector(
dimensions);
268 for(
size_t i = 0; i < (constants.
n_bar() * constants.
n_bar()) / 8; ++i) {
270 for(
size_t j = 0; j < constants.
b(); ++j) {
271 temp |=
static_cast<uint64_t
>(in[i * constants.
b() + j]) << (8 * j);
273 for(
size_t j = 0; j < 8; ++j) {
274 elements.at(pos++) =
static_cast<uint16_t
>((temp & mask) << (constants.
d() - constants.
b()));
275 temp >>= constants.
b();
288 auto elements = make_elements_vector(a.
dimensions());
290 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
303 auto elements = make_elements_vector(a.
dimensions());
305 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
315 return CT::is_equal(
reinterpret_cast<const uint8_t*
>(m_elements.data()),
316 reinterpret_cast<const uint8_t*
>(other.m_elements.data()),
317 sizeof(
decltype(m_elements)::value_type) * m_elements.size());
322 auto elements = make_elements_vector(
dimensions);
324 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
325 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
326 auto& current = elements.at(i * constants.
n_bar() + j);
328 for(
size_t k = 0; k < constants.
n(); ++k) {
330 const uint32_t b_ink =
b.elements_at(i * constants.
n() + k);
333 const uint32_t s_ink = s.
elements_at(j * constants.
n() + k);
335 current +=
static_cast<uint16_t
>(b_ink * s_ink);
366 const uint8_t nbits = std::min(
static_cast<uint8_t
>(8 -
b), bits);
367 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
368 const auto t =
static_cast<uint8_t
>((w >> (bits - nbits)) & mask);
369 out[i] = out[i] +
static_cast<uint8_t
>(t << (8 -
b - nbits));
375 w = m_elements.at(j);
376 bits =
static_cast<uint8_t
>(constants.
d());
392 for(
unsigned int i = 0; i < m_elements.size(); ++i) {
400 const size_t nwords = (constants.
n_bar() * constants.
n_bar()) / 8;
401 const uint16_t maskex =
static_cast<uint16_t
>(1 << constants.
b()) - 1;
402 const uint16_t maskq =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
407 for(
size_t i = 0; i < nwords; i++) {
408 uint64_t templong = 0;
409 for(
size_t j = 0; j < 8; j++) {
411 static_cast<uint16_t
>(((m_elements.at(index) & maskq) + (1 << (constants.
d() - constants.
b() - 1))) >>
412 (constants.
d() - constants.
b()));
413 templong |=
static_cast<uint64_t
>(temp & maskex) << (constants.
b() * j);
416 for(
size_t j = 0; j < constants.
b(); j++) {
417 out[i * constants.
b() + j] = (templong >> (8 * j)) & 0xFF;
427 const uint8_t lsb =
static_cast<uint8_t
>(constants.
d());
428 const size_t inlen = packed_bytes.
size();
433 auto elements = make_elements_vector(
dimensions);
440 while(i < outlen && (j < inlen || ((j == inlen) && (bits > 0)))) {
454 const uint8_t nbits = std::min(
static_cast<uint8_t
>(lsb -
b), bits);
455 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
456 uint8_t t = (w >> (bits - nbits)) & mask;
458 elements.at(i) = elements.at(i) +
static_cast<uint16_t
>(t << (lsb -
b - nbits));
461 w &=
static_cast<uint8_t
>(~(mask << bits));
482 auto elements = make_elements_vector(
dimensions);
490 if(constants.
d() <
sizeof(
decltype(m_elements)::value_type) * 8) {
491 const uint16_t mask =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
492 for(
auto& elem : m_elements) {
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_ASSERT(expr, assertion_made)
#define BOTAN_ASSERT_UNREACHABLE()
Helper class to ease in-place marshalling of concatenated fixed-length values.
static constexpr Mask< T > expand_bit(T v, size_t bit)
static constexpr Mask< T > is_lt(T x, T y)
uint16_t cdf_table_at(size_t i) const
size_t cdf_table_len() const
static FrodoMatrix mul_add_sa_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static std::function< FrodoMatrix(const Dimensions &dimensions)> make_sample_generator(const FrodoKEMConstants &constants, Botan::XOF &shake)
static FrodoMatrix mul_add_sb_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &b, const FrodoMatrix &s, const FrodoMatrix &e)
void reduce(const FrodoKEMConstants &constants)
static FrodoMatrix mul_add_as_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static FrodoMatrix sub(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
FrodoPlaintext decode(const FrodoKEMConstants &constants) const
static FrodoMatrix add(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
std::tuple< size_t, size_t > Dimensions
static FrodoMatrix sample(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoSampleR > r)
CT::Mask< uint8_t > constant_time_compare(const FrodoMatrix &other) const
FrodoPackedMatrix pack(const FrodoKEMConstants &constants) const
static FrodoMatrix encode(const FrodoKEMConstants &constants, StrongSpan< const FrodoPlaintext > in)
Dimensions dimensions() const
FrodoMatrix(Dimensions dims)
static FrodoMatrix unpack(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoPackedMatrix > packed_bytes)
uint16_t elements_at(size_t i) const
size_t element_count() const
static FrodoMatrix mul_bs(const FrodoKEMConstants &constants, const FrodoMatrix &b_p, const FrodoMatrix &s)
static FrodoMatrix deserialize(const Dimensions &dimensions, StrongSpan< const FrodoSerializedMatrix > bytes)
size_t packed_size(const FrodoKEMConstants &constants) const
FrodoSerializedMatrix serialize() const
decltype(auto) data() noexcept(noexcept(this->m_span.data()))
decltype(auto) size() const noexcept(noexcept(this->m_span.size()))
decltype(auto) data() noexcept(noexcept(this->get().data()))
constexpr T value_barrier(T x)
constexpr CT::Mask< T > is_equal(const T x[], const T y[], size_t len)
auto create_shake_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
constexpr auto store_le(ParamTs &&... params)
auto create_aes_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
constexpr auto load_le(ParamTs &&... params)
std::vector< T, secure_allocator< T > > secure_vector
constexpr T ceil_tobytes(T bits)