13#include <botan/internal/frodo_matrix.h>
15#include <botan/assert.h>
16#include <botan/frodokem.h>
18#include <botan/mem_ops.h>
20#include <botan/internal/bit_ops.h>
21#include <botan/internal/frodo_constants.h>
22#include <botan/internal/loadstor.h>
23#include <botan/internal/stl_util.h>
25#if defined(BOTAN_HAS_FRODOKEM_AES)
26 #include <botan/internal/frodo_aes_generator.h>
29#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
30 #include <botan/internal/frodo_shake_generator.h>
46 return secure_vector<uint16_t>(
static_cast<size_t>(std::get<0>(dimensions)) * std::get<1>(dimensions));
49std::function<void(std::span<uint8_t> out, uint16_t i)> make_row_generator(
const FrodoKEMConstants& constants,
50 StrongSpan<const FrodoSeedA> seed_a) {
51#if defined(BOTAN_HAS_FRODOKEM_AES)
52 if(constants.mode().is_aes()) {
57#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
58 if(constants.mode().is_shake()) {
74 const auto n = r.
size() / 2;
76 auto elements = make_elements_vector(
dimensions);
79 load_le<uint16_t>(elements.data(), r.
data(), n);
81 for(
size_t i = 0; i < n; ++i) {
83 const uint16_t prnd = elements.at(i) >> 1;
84 const uint16_t sign = elements.at(i) & 0x1;
92 elements.at(i) =
static_cast<uint16_t
>((-sign) ^
sample) + sign;
108 m_dim1(std::get<0>(dims)), m_dim2(std::get<1>(dims)), m_elements(make_elements_vector(dims)) {}
116 "FrodoMatrix dimension mismatch of E and S");
118 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
120 auto elements = make_elements_vector(e.
dimensions());
121 auto row_generator = make_row_generator(constants, seed_a);
127 std::vector<uint16_t> a_row_data(4 * constants.
n(), 0);
130 std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
131 sizeof(uint16_t) * a_row_data.size());
133 for(
size_t i = 0; i < constants.
n(); i += 4) {
137 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 0));
138 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 1));
139 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 2));
140 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 3));
143 load_le<uint16_t>(a_row_data.data(), a_row_data_bytes.data(), 4 * constants.
n());
145 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
146 std::array<uint16_t, 4> sum = {0};
147 for(
size_t j = 0; j < constants.
n(); ++j) {
152 const uint32_t sp = s.
elements_at(k * constants.
n() + j);
155 sum.at(0) +=
static_cast<uint16_t
>(a_row_data.at(0 * constants.
n() + j) * sp);
156 sum.at(1) +=
static_cast<uint16_t
>(a_row_data.at(1 * constants.
n() + j) * sp);
157 sum.at(2) +=
static_cast<uint16_t
>(a_row_data.at(2 * constants.
n() + j) * sp);
158 sum.at(3) +=
static_cast<uint16_t
>(a_row_data.at(3 * constants.
n() + j) * sp);
160 elements.at((i + 0) * constants.
n_bar() + k) = e.
elements_at((i + 0) * constants.
n_bar() + k) + sum.at(0);
161 elements.at((i + 3) * constants.
n_bar() + k) = e.
elements_at((i + 3) * constants.
n_bar() + k) + sum.at(3);
162 elements.at((i + 2) * constants.
n_bar() + k) = e.
elements_at((i + 2) * constants.
n_bar() + k) + sum.at(2);
163 elements.at((i + 1) * constants.
n_bar() + k) = e.
elements_at((i + 1) * constants.
n_bar() + k) + sum.at(1);
176 "FrodoMatrix dimension mismatch of E and S");
178 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
180 auto elements = e.m_elements;
181 auto row_generator = make_row_generator(constants, seed_a);
187 std::vector<uint16_t> a_row_data(8 * constants.
n(), 0);
189 std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
190 sizeof(uint16_t) * a_row_data.size());
193 for(
size_t i = 0; i < constants.
n(); i += 8) {
197 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 0));
198 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 1));
199 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 2));
200 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 3));
201 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 4));
202 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 5));
203 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 6));
204 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 7));
207 load_le<uint16_t>(a_row_data.data(), a_row_data_bytes.data(), 8 * constants.
n());
209 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
211 std::array<uint32_t , 8> sp;
212 for(
size_t p = 0; p < 8; ++p) {
215 for(
size_t q = 0; q < constants.
n(); ++q) {
216 sum = elements.at(j * constants.
n() + q);
217 for(
size_t p = 0; p < 8; ++p) {
218 sum +=
static_cast<uint16_t
>(sp[p] * a_row_data.at(p * constants.
n() + q));
220 elements.at(j * constants.
n() + q) = sum;
234 "FrodoMatrix dimension mismatch of B and S");
236 "FrodoMatrix dimension mismatch of B");
238 "FrodoMatrix dimension mismatch of E");
240 auto elements = make_elements_vector(e.
dimensions());
242 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
243 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
245 for(
size_t j = 0; j < constants.
n(); ++j) {
246 elements.at(k * constants.
n_bar() + i) +=
static_cast<uint16_t
>(
247 static_cast<uint32_t
>(s.
elements_at(k * constants.
n() + j)) *
257 const uint64_t mask = (uint64_t(1) << constants.
b()) - 1;
260 auto elements = make_elements_vector(
dimensions);
265 for(
size_t i = 0; i < (constants.
n_bar() * constants.
n_bar()) / 8; ++i) {
267 for(
size_t j = 0; j < constants.
b(); ++j) {
268 temp |=
static_cast<uint64_t
>(in[i * constants.
b() + j]) << (8 * j);
270 for(
size_t j = 0; j < 8; ++j) {
271 elements.at(pos++) =
static_cast<uint16_t
>((temp & mask) << (constants.
d() - constants.
b()));
272 temp >>= constants.
b();
285 auto elements = make_elements_vector(a.
dimensions());
287 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
300 auto elements = make_elements_vector(a.
dimensions());
302 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
312 return CT::is_equal(
reinterpret_cast<const uint8_t*
>(m_elements.data()),
313 reinterpret_cast<const uint8_t*
>(other.m_elements.data()),
314 sizeof(
decltype(m_elements)::value_type) * m_elements.size());
319 auto elements = make_elements_vector(
dimensions);
321 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
322 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
323 auto& current = elements.at(i * constants.
n_bar() + j);
325 for(
size_t k = 0; k < constants.
n(); ++k) {
327 const uint32_t b_ink = b.
elements_at(i * constants.
n() + k);
330 const uint32_t s_ink = s.
elements_at(j * constants.
n() + k);
332 current +=
static_cast<uint16_t
>(b_ink * s_ink);
363 const uint8_t nbits = std::min(
static_cast<uint8_t
>(8 - b), bits);
364 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
365 const auto t =
static_cast<uint8_t
>((w >> (bits - nbits)) & mask);
366 out[i] = out[i] +
static_cast<uint8_t
>(t << (8 - b - nbits));
372 w = m_elements.at(j);
373 bits =
static_cast<uint8_t
>(constants.
d());
389 for(
unsigned int i = 0; i < m_elements.size(); ++i) {
397 const size_t nwords = (constants.
n_bar() * constants.
n_bar()) / 8;
398 const uint16_t maskex =
static_cast<uint16_t
>(1 << constants.
b()) - 1;
399 const uint16_t maskq =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
404 for(
size_t i = 0; i < nwords; i++) {
405 uint64_t templong = 0;
406 for(
size_t j = 0; j < 8; j++) {
408 static_cast<uint16_t
>(((m_elements.at(index) & maskq) + (1 << (constants.
d() - constants.
b() - 1))) >>
409 (constants.
d() - constants.
b()));
410 templong |=
static_cast<uint64_t
>(temp & maskex) << (constants.
b() * j);
413 for(
size_t j = 0; j < constants.
b(); j++) {
414 out[i * constants.
b() + j] = (templong >> (8 * j)) & 0xFF;
424 const uint8_t lsb =
static_cast<uint8_t
>(constants.
d());
425 const size_t inlen = packed_bytes.
size();
430 auto elements = make_elements_vector(
dimensions);
437 while(i < outlen && (j < inlen || ((j == inlen) && (bits > 0)))) {
451 const uint8_t nbits = std::min(
static_cast<uint8_t
>(lsb - b), bits);
452 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
453 uint8_t t = (w >> (bits - nbits)) & mask;
455 elements.at(i) = elements.at(i) +
static_cast<uint16_t
>(t << (lsb - b - nbits));
458 w &=
static_cast<uint8_t
>(~(mask << bits));
479 auto elements = make_elements_vector(
dimensions);
481 load_le<uint16_t>(elements.data(), bytes.
data(), elements.size());
487 if(constants.
d() <
sizeof(
decltype(m_elements)::value_type) * 8) {
488 const uint16_t mask =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
489 for(
auto& elem : m_elements) {
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_ASSERT(expr, assertion_made)
#define BOTAN_ASSERT_UNREACHABLE()
Helper class to ease in-place marshalling of concatenated fixed-length values.
static constexpr Mask< T > is_lt(T x, T y)
uint16_t cdf_table_at(size_t i) const
size_t cdf_table_len() const
static std::function< FrodoMatrix(const Dimensions &dimensions) make_sample_generator)(const FrodoKEMConstants &constants, Botan::XOF &shake)
static FrodoMatrix mul_add_sa_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static FrodoMatrix mul_add_sb_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &b, const FrodoMatrix &s, const FrodoMatrix &e)
void reduce(const FrodoKEMConstants &constants)
static FrodoMatrix mul_add_as_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static FrodoMatrix sub(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
FrodoPlaintext decode(const FrodoKEMConstants &constants) const
static FrodoMatrix add(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
std::tuple< size_t, size_t > Dimensions
static FrodoMatrix sample(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoSampleR > r)
CT::Mask< uint8_t > constant_time_compare(const FrodoMatrix &other) const
FrodoPackedMatrix pack(const FrodoKEMConstants &constants) const
static FrodoMatrix encode(const FrodoKEMConstants &constants, StrongSpan< const FrodoPlaintext > in)
Dimensions dimensions() const
FrodoMatrix(Dimensions dims)
static FrodoMatrix unpack(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoPackedMatrix > packed_bytes)
uint16_t elements_at(size_t i) const
size_t element_count() const
static FrodoMatrix mul_bs(const FrodoKEMConstants &constants, const FrodoMatrix &b_p, const FrodoMatrix &s)
static FrodoMatrix deserialize(const Dimensions &dimensions, StrongSpan< const FrodoSerializedMatrix > bytes)
size_t packed_size(const FrodoKEMConstants &constants) const
FrodoSerializedMatrix serialize() const
decltype(auto) data() noexcept(noexcept(this->m_span.data()))
decltype(auto) size() const noexcept(noexcept(this->m_span.size()))
decltype(auto) data() noexcept(noexcept(this->get().data()))
constexpr CT::Mask< T > is_equal(const T x[], const T y[], size_t len)
auto create_shake_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
constexpr auto store_le(ParamTs &&... params)
auto create_aes_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
constexpr T ceil_tobytes(T bits)