13#include <botan/internal/frodo_matrix.h>
15#include <botan/assert.h>
17#include <botan/internal/bit_ops.h>
18#include <botan/internal/frodo_constants.h>
19#include <botan/internal/loadstor.h>
20#include <botan/internal/stl_util.h>
22#if defined(BOTAN_HAS_FRODOKEM_AES)
23 #include <botan/internal/frodo_aes_generator.h>
26#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
27 #include <botan/internal/frodo_shake_generator.h>
43std::function<void(std::span<uint8_t> out, uint16_t i)> make_row_generator(
const FrodoKEMConstants& constants,
45#if defined(BOTAN_HAS_FRODOKEM_AES)
46 if(constants.mode().is_aes()) {
51#if defined(BOTAN_HAS_FRODOKEM_SHAKE)
52 if(constants.mode().is_shake()) {
68 const auto n = r.
size() / 2;
70 auto elements = make_elements_vector(
dimensions);
75 for(
auto& elem : elements) {
87 const uint16_t sample_u16 =
static_cast<uint16_t
>(
sample);
89 elem = sign.select(~sample_u16 + 1, sample_u16);
105 m_dim1(std::get<0>(dims)), m_dim2(std::get<1>(dims)), m_elements(make_elements_vector(dims)) {}
113 "FrodoMatrix dimension mismatch of E and S");
115 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
117 auto elements = make_elements_vector(e.
dimensions());
118 auto row_generator = make_row_generator(constants, seed_a);
124 std::vector<uint16_t> a_row_data(4 * constants.
n(), 0);
127 std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
128 sizeof(uint16_t) * a_row_data.size());
130 for(
size_t i = 0; i < constants.
n(); i += 4) {
134 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 0));
135 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 1));
136 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 2));
137 row_generator(a_row.next(constants.
n() *
sizeof(uint16_t)),
static_cast<uint16_t
>(i + 3));
142 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
143 std::array<uint16_t, 4> sum = {0};
144 for(
size_t j = 0; j < constants.
n(); ++j) {
149 const uint32_t sp = s.
elements_at(k * constants.
n() + j);
152 sum.at(0) +=
static_cast<uint16_t
>(a_row_data.at(0 * constants.
n() + j) * sp);
153 sum.at(1) +=
static_cast<uint16_t
>(a_row_data.at(1 * constants.
n() + j) * sp);
154 sum.at(2) +=
static_cast<uint16_t
>(a_row_data.at(2 * constants.
n() + j) * sp);
155 sum.at(3) +=
static_cast<uint16_t
>(a_row_data.at(3 * constants.
n() + j) * sp);
157 elements.at((i + 0) * constants.
n_bar() + k) = e.
elements_at((i + 0) * constants.
n_bar() + k) + sum.at(0);
158 elements.at((i + 3) * constants.
n_bar() + k) = e.
elements_at((i + 3) * constants.
n_bar() + k) + sum.at(3);
159 elements.at((i + 2) * constants.
n_bar() + k) = e.
elements_at((i + 2) * constants.
n_bar() + k) + sum.at(2);
160 elements.at((i + 1) * constants.
n_bar() + k) = e.
elements_at((i + 1) * constants.
n_bar() + k) + sum.at(1);
173 "FrodoMatrix dimension mismatch of E and S");
175 "FrodoMatrix dimension mismatch of new matrix dimensions and E");
177 auto elements = e.m_elements;
178 auto row_generator = make_row_generator(constants, seed_a);
184 std::vector<uint16_t> a_row_data(8 * constants.
n(), 0);
186 std::span<uint8_t> a_row_data_bytes(
reinterpret_cast<uint8_t*
>(a_row_data.data()),
187 sizeof(uint16_t) * a_row_data.size());
190 for(
size_t i = 0; i < constants.
n(); i += 8) {
194 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 0));
195 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 1));
196 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 2));
197 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 3));
198 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 4));
199 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 5));
200 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 6));
201 row_generator(a_row.next(
sizeof(uint16_t) * constants.
n()),
static_cast<uint16_t
>(i + 7));
206 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
208 std::array<uint32_t , 8> sp{};
209 for(
size_t p = 0; p < 8; ++p) {
212 for(
size_t q = 0; q < constants.
n(); ++q) {
213 sum = elements.at(j * constants.
n() + q);
214 for(
size_t p = 0; p < 8; ++p) {
215 sum +=
static_cast<uint16_t
>(sp[p] * a_row_data.at(p * constants.
n() + q));
217 elements.at(j * constants.
n() + q) = sum;
231 "FrodoMatrix dimension mismatch of B and S");
233 "FrodoMatrix dimension mismatch of B");
235 "FrodoMatrix dimension mismatch of E");
237 auto elements = make_elements_vector(e.
dimensions());
239 for(
size_t k = 0; k < constants.
n_bar(); ++k) {
240 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
242 for(
size_t j = 0; j < constants.
n(); ++j) {
243 elements.at(k * constants.
n_bar() + i) +=
static_cast<uint16_t
>(
244 static_cast<uint32_t
>(s.
elements_at(k * constants.
n() + j)) *
254 const uint64_t mask = (uint64_t(1) << constants.
b()) - 1;
257 auto elements = make_elements_vector(
dimensions);
262 for(
size_t i = 0; i < (constants.
n_bar() * constants.
n_bar()) / 8; ++i) {
264 for(
size_t j = 0; j < constants.
b(); ++j) {
265 temp |=
static_cast<uint64_t
>(in[i * constants.
b() + j]) << (8 * j);
267 for(
size_t j = 0; j < 8; ++j) {
268 elements.at(pos++) =
static_cast<uint16_t
>((temp & mask) << (constants.
d() - constants.
b()));
269 temp >>= constants.
b();
282 auto elements = make_elements_vector(a.
dimensions());
284 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
297 auto elements = make_elements_vector(a.
dimensions());
299 for(
size_t i = 0; i < constants.
n_bar() * constants.
n_bar(); ++i) {
309 return CT::is_equal(
reinterpret_cast<const uint8_t*
>(m_elements.data()),
310 reinterpret_cast<const uint8_t*
>(other.m_elements.data()),
311 sizeof(
decltype(m_elements)::value_type) * m_elements.size());
316 auto elements = make_elements_vector(
dimensions);
318 for(
size_t i = 0; i < constants.
n_bar(); ++i) {
319 for(
size_t j = 0; j < constants.
n_bar(); ++j) {
320 auto& current = elements.at(i * constants.
n_bar() + j);
322 for(
size_t k = 0; k < constants.
n(); ++k) {
324 const uint32_t b_ink = b.
elements_at(i * constants.
n() + k);
327 const uint32_t s_ink = s.
elements_at(j * constants.
n() + k);
329 current +=
static_cast<uint16_t
>(b_ink * s_ink);
360 const uint8_t nbits = std::min(
static_cast<uint8_t
>(8 - b), bits);
361 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
362 const auto t =
static_cast<uint8_t
>((w >> (bits - nbits)) & mask);
363 out[i] = out[i] +
static_cast<uint8_t
>(t << (8 - b - nbits));
369 w = m_elements.at(j);
370 bits =
static_cast<uint8_t
>(constants.
d());
386 for(
unsigned int i = 0; i < m_elements.size(); ++i) {
394 const size_t nwords = (constants.
n_bar() * constants.
n_bar()) / 8;
395 const uint16_t maskex =
static_cast<uint16_t
>(1 << constants.
b()) - 1;
396 const uint16_t maskq =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
401 for(
size_t i = 0; i < nwords; i++) {
402 uint64_t templong = 0;
403 for(
size_t j = 0; j < 8; j++) {
405 static_cast<uint16_t
>(((m_elements.at(index) & maskq) + (1 << (constants.
d() - constants.
b() - 1))) >>
406 (constants.
d() - constants.
b()));
407 templong |=
static_cast<uint64_t
>(temp & maskex) << (constants.
b() * j);
410 for(
size_t j = 0; j < constants.
b(); j++) {
411 out[i * constants.
b() + j] = (templong >> (8 * j)) & 0xFF;
421 const uint8_t lsb =
static_cast<uint8_t
>(constants.
d());
422 const size_t inlen = packed_bytes.
size();
427 auto elements = make_elements_vector(
dimensions);
434 while(i < outlen && (j < inlen || ((j == inlen) && (bits > 0)))) {
448 const uint8_t nbits = std::min(
static_cast<uint8_t
>(lsb - b), bits);
449 const uint16_t mask =
static_cast<uint16_t
>(1 << nbits) - 1;
450 uint8_t t = (w >> (bits - nbits)) & mask;
452 elements.at(i) = elements.at(i) +
static_cast<uint16_t
>(t << (lsb - b - nbits));
455 w &=
static_cast<uint8_t
>(~(mask << bits));
476 auto elements = make_elements_vector(
dimensions);
484 if(constants.
d() <
sizeof(
decltype(m_elements)::value_type) * 8) {
485 const uint16_t mask =
static_cast<uint16_t
>(1 << constants.
d()) - 1;
486 for(
auto& elem : m_elements) {
#define BOTAN_ASSERT_NOMSG(expr)
#define BOTAN_ASSERT(expr, assertion_made)
#define BOTAN_ASSERT_UNREACHABLE()
Helper class to ease in-place marshalling of concatenated fixed-length values.
static constexpr Mask< T > expand_bit(T v, size_t bit)
static constexpr Mask< T > is_lt(T x, T y)
uint16_t cdf_table_at(size_t i) const
size_t cdf_table_len() const
static FrodoMatrix mul_add_sa_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static std::function< FrodoMatrix(const Dimensions &dimensions)> make_sample_generator(const FrodoKEMConstants &constants, Botan::XOF &shake)
static FrodoMatrix mul_add_sb_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &b, const FrodoMatrix &s, const FrodoMatrix &e)
void reduce(const FrodoKEMConstants &constants)
static FrodoMatrix mul_add_as_plus_e(const FrodoKEMConstants &constants, const FrodoMatrix &s, const FrodoMatrix &e, StrongSpan< const FrodoSeedA > seed_a)
static FrodoMatrix sub(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
FrodoPlaintext decode(const FrodoKEMConstants &constants) const
static FrodoMatrix add(const FrodoKEMConstants &constants, const FrodoMatrix &a, const FrodoMatrix &b)
std::tuple< size_t, size_t > Dimensions
static FrodoMatrix sample(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoSampleR > r)
CT::Mask< uint8_t > constant_time_compare(const FrodoMatrix &other) const
FrodoPackedMatrix pack(const FrodoKEMConstants &constants) const
static FrodoMatrix encode(const FrodoKEMConstants &constants, StrongSpan< const FrodoPlaintext > in)
Dimensions dimensions() const
FrodoMatrix(Dimensions dims)
static FrodoMatrix unpack(const FrodoKEMConstants &constants, const Dimensions &dimensions, StrongSpan< const FrodoPackedMatrix > packed_bytes)
uint16_t elements_at(size_t i) const
size_t element_count() const
static FrodoMatrix mul_bs(const FrodoKEMConstants &constants, const FrodoMatrix &b_p, const FrodoMatrix &s)
static FrodoMatrix deserialize(const Dimensions &dimensions, StrongSpan< const FrodoSerializedMatrix > bytes)
size_t packed_size(const FrodoKEMConstants &constants) const
FrodoSerializedMatrix serialize() const
decltype(auto) data() noexcept(noexcept(this->m_span.data()))
decltype(auto) size() const noexcept(noexcept(this->m_span.size()))
decltype(auto) data() noexcept(noexcept(this->get().data()))
constexpr T value_barrier(T x)
constexpr CT::Mask< T > is_equal(const T x[], const T y[], size_t len)
Strong< secure_vector< uint8_t >, struct FrodoSerializedMatrix_ > FrodoSerializedMatrix
auto create_shake_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
constexpr auto store_le(ParamTs &&... params)
auto create_aes_row_generator(const FrodoKEMConstants &constants, StrongSpan< const FrodoSeedA > seed_a)
BOTAN_FORCE_INLINE constexpr T ceil_tobytes(T bits)
constexpr auto load_le(ParamTs &&... params)
std::vector< T, secure_allocator< T > > secure_vector
Strong< secure_vector< uint8_t >, struct FrodoPlaintext_ > FrodoPlaintext
Strong< secure_vector< uint8_t >, struct FrodoSampleR_ > FrodoSampleR