Botan 3.11.0
Crypto and TLS for C++
sm4_avx512.cpp
Go to the documentation of this file.
1/*
2* (C) 2025 Jack Lloyd
3*
4* Botan is released under the Simplified BSD License (see license.txt)
5*/
6
7#include <botan/internal/sm4.h>
8
9#include <botan/mem_ops.h>
10#include <botan/internal/isa_extn.h>
11#include <botan/internal/simd_avx2_gfni.h>
12#include <botan/internal/simd_avx512.h>
13
14namespace Botan {
15
16namespace SM4_AVX512_GFNI {
17
18namespace {
19
20template <uint64_t A, uint8_t B>
21BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI SIMD_16x32 gf2p8affine(const SIMD_16x32& x) {
22 return SIMD_16x32(_mm512_gf2p8affine_epi64_epi8(x.raw(), _mm512_set1_epi64(A), B));
23}
24
25template <uint64_t A, uint8_t B>
26BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI SIMD_16x32 gf2p8affineinv(const SIMD_16x32& x) {
27 return SIMD_16x32(_mm512_gf2p8affineinv_epi64_epi8(x.raw(), _mm512_set1_epi64(A), B));
28}
29
30template <typename SIMD_T>
31BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI SIMD_T sm4_sbox(const SIMD_T& x) {
32 /*
33 * See https://eprint.iacr.org/2022/1154 section 3.3 for details on
34 * how this works
35 */
36 constexpr uint64_t pre_a = gfni_matrix(R"(
37 0 0 1 1 0 0 1 0
38 0 0 0 1 0 1 0 0
39 1 0 1 1 1 1 1 0
40 1 0 0 1 1 1 0 1
41 0 1 0 1 1 0 0 0
42 0 1 0 0 0 1 0 0
43 0 0 0 0 1 0 1 0
44 1 0 1 1 1 0 1 0)");
45
46 constexpr uint8_t pre_c = 0b00111110;
47
48 constexpr uint64_t post_a = gfni_matrix(R"(
49 1 1 0 0 1 1 1 1
50 1 1 0 1 0 1 0 1
51 0 0 1 0 1 1 0 0
52 1 0 0 1 0 1 0 1
53 0 0 1 0 1 1 1 0
54 0 1 1 0 0 1 0 1
55 1 0 1 0 1 1 0 1
56 1 0 0 1 0 0 0 1)");
57
58 constexpr uint8_t post_c = 0b11010011;
59
60 auto y = gf2p8affine<pre_a, pre_c>(x);
62}
63
64template <typename SIMD_T>
65BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI SIMD_T sm4_f(const SIMD_T& x) {
66 const auto sx = sm4_sbox(x);
67 return sx ^ sx.template rotl<2>() ^ sx.template rotl<10>() ^ sx.template rotl<18>() ^ sx.template rotl<24>();
68}
69
70template <typename SIMD_T, size_t M>
71BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI void encrypt(const uint8_t ptext[16 * 4 * M],
72 uint8_t ctext[16 * 4 * M],
73 std::span<const uint32_t> RK) {
74 SIMD_T B0 = SIMD_T::load_be(ptext);
75 SIMD_T B1 = SIMD_T::load_be(ptext + 16 * M);
76 SIMD_T B2 = SIMD_T::load_be(ptext + 16 * 2 * M);
77 SIMD_T B3 = SIMD_T::load_be(ptext + 16 * 3 * M);
78
79 SIMD_T::transpose(B0, B1, B2, B3);
80
81 B0 = B0.rev_words();
82 B1 = B1.rev_words();
83 B2 = B2.rev_words();
84 B3 = B3.rev_words();
85
86 for(size_t j = 0; j != 8; ++j) {
87 B0 ^= sm4_f(B1 ^ B2 ^ B3 ^ SIMD_T::splat(RK[4 * j]));
88 B1 ^= sm4_f(B2 ^ B3 ^ B0 ^ SIMD_T::splat(RK[4 * j + 1]));
89 B2 ^= sm4_f(B3 ^ B0 ^ B1 ^ SIMD_T::splat(RK[4 * j + 2]));
90 B3 ^= sm4_f(B0 ^ B1 ^ B2 ^ SIMD_T::splat(RK[4 * j + 3]));
91 }
92
93 SIMD_T::transpose(B0, B1, B2, B3);
94
95 B3.rev_words().store_be(ctext);
96 B2.rev_words().store_be(ctext + 16 * M);
97 B1.rev_words().store_be(ctext + 16 * 2 * M);
98 B0.rev_words().store_be(ctext + 16 * 3 * M);
99}
100
101template <typename SIMD_T, size_t M>
102BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI void encrypt_x2(const uint8_t ptext[32 * 4 * M],
103 uint8_t ctext[32 * 4 * M],
104 std::span<const uint32_t> RK) {
105 SIMD_T B0 = SIMD_T::load_be(ptext);
106 SIMD_T B1 = SIMD_T::load_be(ptext + 16 * M);
107 SIMD_T B2 = SIMD_T::load_be(ptext + 16 * 2 * M);
108 SIMD_T B3 = SIMD_T::load_be(ptext + 16 * 3 * M);
109
110 SIMD_T B4 = SIMD_T::load_be(ptext + 16 * 4 * M);
111 SIMD_T B5 = SIMD_T::load_be(ptext + 16 * 5 * M);
112 SIMD_T B6 = SIMD_T::load_be(ptext + 16 * 6 * M);
113 SIMD_T B7 = SIMD_T::load_be(ptext + 16 * 7 * M);
114
115 SIMD_T::transpose(B0, B1, B2, B3);
116 SIMD_T::transpose(B4, B5, B6, B7);
117
118 B0 = B0.rev_words();
119 B1 = B1.rev_words();
120 B2 = B2.rev_words();
121 B3 = B3.rev_words();
122
123 B4 = B4.rev_words();
124 B5 = B5.rev_words();
125 B6 = B6.rev_words();
126 B7 = B7.rev_words();
127
128 for(size_t j = 0; j != 8; ++j) {
129 B0 ^= sm4_f(B1 ^ B2 ^ B3 ^ SIMD_T::splat(RK[4 * j]));
130 B4 ^= sm4_f(B5 ^ B6 ^ B7 ^ SIMD_T::splat(RK[4 * j]));
131
132 B1 ^= sm4_f(B2 ^ B3 ^ B0 ^ SIMD_T::splat(RK[4 * j + 1]));
133 B5 ^= sm4_f(B6 ^ B7 ^ B4 ^ SIMD_T::splat(RK[4 * j + 1]));
134
135 B2 ^= sm4_f(B3 ^ B0 ^ B1 ^ SIMD_T::splat(RK[4 * j + 2]));
136 B6 ^= sm4_f(B7 ^ B4 ^ B5 ^ SIMD_T::splat(RK[4 * j + 2]));
137
138 B3 ^= sm4_f(B0 ^ B1 ^ B2 ^ SIMD_T::splat(RK[4 * j + 3]));
139 B7 ^= sm4_f(B4 ^ B5 ^ B6 ^ SIMD_T::splat(RK[4 * j + 3]));
140 }
141
142 SIMD_T::transpose(B0, B1, B2, B3);
143 SIMD_T::transpose(B4, B5, B6, B7);
144
145 B3.rev_words().store_be(ctext);
146 B2.rev_words().store_be(ctext + 16 * M);
147 B1.rev_words().store_be(ctext + 16 * 2 * M);
148 B0.rev_words().store_be(ctext + 16 * 3 * M);
149
150 B7.rev_words().store_be(ctext + 16 * 4 * M);
151 B6.rev_words().store_be(ctext + 16 * 5 * M);
152 B5.rev_words().store_be(ctext + 16 * 6 * M);
153 B4.rev_words().store_be(ctext + 16 * 7 * M);
154}
155
156template <typename SIMD_T, size_t M>
157BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI void decrypt(const uint8_t ctext[16 * 4 * M],
158 uint8_t ptext[16 * 4 * M],
159 std::span<const uint32_t> RK) {
160 SIMD_T B0 = SIMD_T::load_be(ctext);
161 SIMD_T B1 = SIMD_T::load_be(ctext + 16 * M);
162 SIMD_T B2 = SIMD_T::load_be(ctext + 16 * 2 * M);
163 SIMD_T B3 = SIMD_T::load_be(ctext + 16 * 3 * M);
164
165 SIMD_T::transpose(B0, B1, B2, B3);
166
167 B0 = B0.rev_words();
168 B1 = B1.rev_words();
169 B2 = B2.rev_words();
170 B3 = B3.rev_words();
171
172 for(size_t j = 0; j != 8; ++j) {
173 B0 ^= sm4_f(B1 ^ B2 ^ B3 ^ SIMD_T::splat(RK[32 - (4 * j + 1)]));
174 B1 ^= sm4_f(B2 ^ B3 ^ B0 ^ SIMD_T::splat(RK[32 - (4 * j + 2)]));
175 B2 ^= sm4_f(B3 ^ B0 ^ B1 ^ SIMD_T::splat(RK[32 - (4 * j + 3)]));
176 B3 ^= sm4_f(B0 ^ B1 ^ B2 ^ SIMD_T::splat(RK[32 - (4 * j + 4)]));
177 }
178
179 SIMD_T::transpose(B0, B1, B2, B3);
180
181 B3.rev_words().store_be(ptext);
182 B2.rev_words().store_be(ptext + 16 * M);
183 B1.rev_words().store_be(ptext + 16 * 2 * M);
184 B0.rev_words().store_be(ptext + 16 * 3 * M);
185}
186
187template <typename SIMD_T, size_t M>
188BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX512_GFNI void decrypt_x2(const uint8_t ctext[32 * 4 * M],
189 uint8_t ptext[32 * 4 * M],
190 std::span<const uint32_t> RK) {
191 SIMD_T B0 = SIMD_T::load_be(ctext);
192 SIMD_T B1 = SIMD_T::load_be(ctext + 16 * M);
193 SIMD_T B2 = SIMD_T::load_be(ctext + 16 * 2 * M);
194 SIMD_T B3 = SIMD_T::load_be(ctext + 16 * 3 * M);
195
196 SIMD_T B4 = SIMD_T::load_be(ctext + 16 * 4 * M);
197 SIMD_T B5 = SIMD_T::load_be(ctext + 16 * 5 * M);
198 SIMD_T B6 = SIMD_T::load_be(ctext + 16 * 6 * M);
199 SIMD_T B7 = SIMD_T::load_be(ctext + 16 * 7 * M);
200
201 SIMD_T::transpose(B0, B1, B2, B3);
202 SIMD_T::transpose(B4, B5, B6, B7);
203
204 B0 = B0.rev_words();
205 B1 = B1.rev_words();
206 B2 = B2.rev_words();
207 B3 = B3.rev_words();
208
209 B4 = B4.rev_words();
210 B5 = B5.rev_words();
211 B6 = B6.rev_words();
212 B7 = B7.rev_words();
213
214 for(size_t j = 0; j != 8; ++j) {
215 B0 ^= sm4_f(B1 ^ B2 ^ B3 ^ SIMD_T::splat(RK[32 - (4 * j + 1)]));
216 B4 ^= sm4_f(B5 ^ B6 ^ B7 ^ SIMD_T::splat(RK[32 - (4 * j + 1)]));
217
218 B1 ^= sm4_f(B2 ^ B3 ^ B0 ^ SIMD_T::splat(RK[32 - (4 * j + 2)]));
219 B5 ^= sm4_f(B6 ^ B7 ^ B4 ^ SIMD_T::splat(RK[32 - (4 * j + 2)]));
220
221 B2 ^= sm4_f(B3 ^ B0 ^ B1 ^ SIMD_T::splat(RK[32 - (4 * j + 3)]));
222 B6 ^= sm4_f(B7 ^ B4 ^ B5 ^ SIMD_T::splat(RK[32 - (4 * j + 3)]));
223
224 B3 ^= sm4_f(B0 ^ B1 ^ B2 ^ SIMD_T::splat(RK[32 - (4 * j + 4)]));
225 B7 ^= sm4_f(B4 ^ B5 ^ B6 ^ SIMD_T::splat(RK[32 - (4 * j + 4)]));
226 }
227
228 SIMD_T::transpose(B0, B1, B2, B3);
229 SIMD_T::transpose(B4, B5, B6, B7);
230
231 B3.rev_words().store_be(ptext);
232 B2.rev_words().store_be(ptext + 16 * M);
233 B1.rev_words().store_be(ptext + 16 * 2 * M);
234 B0.rev_words().store_be(ptext + 16 * 3 * M);
235
236 B7.rev_words().store_be(ptext + 16 * 4 * M);
237 B6.rev_words().store_be(ptext + 16 * 5 * M);
238 B5.rev_words().store_be(ptext + 16 * 6 * M);
239 B4.rev_words().store_be(ptext + 16 * 7 * M);
240}
241
242} // namespace
243
244} // namespace SM4_AVX512_GFNI
245
246void BOTAN_FN_ISA_AVX512_GFNI SM4::sm4_avx512_gfni_encrypt(const uint8_t ptext[],
247 uint8_t ctext[],
248 size_t blocks) const {
249 while(blocks >= 32) {
250 SM4_AVX512_GFNI::encrypt_x2<SIMD_16x32, 4>(ptext, ctext, m_RK);
251 ptext += 16 * 32;
252 ctext += 16 * 32;
253 blocks -= 32;
254 }
255
256 while(blocks >= 16) {
257 SM4_AVX512_GFNI::encrypt<SIMD_16x32, 4>(ptext, ctext, m_RK);
258 ptext += 16 * 16;
259 ctext += 16 * 16;
260 blocks -= 16;
261 }
262
263 while(blocks >= 8) {
264 SM4_AVX512_GFNI::encrypt<SIMD_8x32, 2>(ptext, ctext, m_RK);
265 ptext += 16 * 8;
266 ctext += 16 * 8;
267 blocks -= 8;
268 }
269
270 if(blocks > 0) {
271 uint8_t pbuf[16 * 8] = {0};
272 uint8_t cbuf[16 * 8] = {0};
273 copy_mem(pbuf, ptext, blocks * 16);
274 SM4_AVX512_GFNI::encrypt<SIMD_8x32, 2>(pbuf, cbuf, m_RK);
275 copy_mem(ctext, cbuf, blocks * 16);
276 }
277}
278
279void BOTAN_FN_ISA_AVX512_GFNI SM4::sm4_avx512_gfni_decrypt(const uint8_t ctext[],
280 uint8_t ptext[],
281 size_t blocks) const {
282 while(blocks >= 32) {
283 SM4_AVX512_GFNI::decrypt_x2<SIMD_16x32, 4>(ctext, ptext, m_RK);
284 ptext += 16 * 32;
285 ctext += 16 * 32;
286 blocks -= 32;
287 }
288
289 while(blocks >= 16) {
290 SM4_AVX512_GFNI::decrypt<SIMD_16x32, 4>(ctext, ptext, m_RK);
291 ptext += 16 * 16;
292 ctext += 16 * 16;
293 blocks -= 16;
294 }
295
296 while(blocks >= 8) {
297 SM4_AVX512_GFNI::decrypt<SIMD_8x32, 2>(ctext, ptext, m_RK);
298 ptext += 16 * 8;
299 ctext += 16 * 8;
300 blocks -= 8;
301 }
302
303 if(blocks > 0) {
304 uint8_t cbuf[16 * 8] = {0};
305 uint8_t pbuf[16 * 8] = {0};
306 copy_mem(cbuf, ctext, blocks * 16);
307 SM4_AVX512_GFNI::decrypt<SIMD_8x32, 2>(cbuf, pbuf, m_RK);
308 copy_mem(ptext, pbuf, blocks * 16);
309 }
310}
311
312} // namespace Botan
__m512i BOTAN_FN_ISA_AVX512 raw() const
#define BOTAN_FORCE_INLINE
Definition compiler.h:87
constexpr void copy_mem(T *out, const T *in, size_t n)
Definition mem_ops.h:144
consteval uint64_t gfni_matrix(std::string_view s)
BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX2_GFNI SIMD_8x32 gf2p8affineinv(const SIMD_8x32 &x)
BOTAN_FORCE_INLINE constexpr T rotl(T input)
Definition rotate.h:23
BOTAN_FORCE_INLINE BOTAN_FN_ISA_AVX2_GFNI SIMD_8x32 gf2p8affine(const SIMD_8x32 &x)