Botan  2.11.0
Crypto and TLS for C++11
mp_karat.cpp
Go to the documentation of this file.
1 /*
2 * Multiplication and Squaring
3 * (C) 1999-2010,2018 Jack Lloyd
4 * 2016 Matthias Gierlings
5 *
6 * Botan is released under the Simplified BSD License (see license.txt)
7 */
8 
9 #include <botan/internal/mp_core.h>
10 #include <botan/internal/mp_asmi.h>
11 #include <botan/internal/ct_utils.h>
12 #include <botan/mem_ops.h>
13 #include <botan/exceptn.h>
14 
15 namespace Botan {
16 
17 namespace {
18 
// Operand lengths (in words) below which karatsuba_mul / karatsuba_sqr stop
// recursing and hand off to a comba routine or the schoolbook basecase.
// bigint_mul / bigint_sqr also consult them before attempting Karatsuba.
constexpr size_t KARATSUBA_MULTIPLY_THRESHOLD = 32;
constexpr size_t KARATSUBA_SQUARE_THRESHOLD = 32;
21 
22 /*
23 * Simple O(N^2) Multiplication
24 */
25 void basecase_mul(word z[], size_t z_size,
26  const word x[], size_t x_size,
27  const word y[], size_t y_size)
28  {
29  if(z_size < x_size + y_size)
30  throw Invalid_Argument("basecase_mul z_size too small");
31 
32  const size_t x_size_8 = x_size - (x_size % 8);
33 
34  clear_mem(z, z_size);
35 
36  for(size_t i = 0; i != y_size; ++i)
37  {
38  const word y_i = y[i];
39 
40  word carry = 0;
41 
42  for(size_t j = 0; j != x_size_8; j += 8)
43  carry = word8_madd3(z + i + j, x + j, y_i, carry);
44 
45  for(size_t j = x_size_8; j != x_size; ++j)
46  z[i+j] = word_madd3(x[j], y_i, z[i+j], &carry);
47 
48  z[x_size+i] = carry;
49  }
50  }
51 
52 void basecase_sqr(word z[], size_t z_size,
53  const word x[], size_t x_size)
54  {
55  if(z_size < 2*x_size)
56  throw Invalid_Argument("basecase_sqr z_size too small");
57 
58  const size_t x_size_8 = x_size - (x_size % 8);
59 
60  clear_mem(z, z_size);
61 
62  for(size_t i = 0; i != x_size; ++i)
63  {
64  const word x_i = x[i];
65 
66  word carry = 0;
67 
68  for(size_t j = 0; j != x_size_8; j += 8)
69  carry = word8_madd3(z + i + j, x + j, x_i, carry);
70 
71  for(size_t j = x_size_8; j != x_size; ++j)
72  z[i+j] = word_madd3(x[j], x_i, z[i+j], &carry);
73 
74  z[x_size+i] = carry;
75  }
76  }
77 
78 /*
79 * Karatsuba Multiplication Operation
80 */
81 void karatsuba_mul(word z[], const word x[], const word y[], size_t N,
82  word workspace[])
83  {
84  if(N < KARATSUBA_MULTIPLY_THRESHOLD || N % 2)
85  {
86  switch(N)
87  {
88  case 6:
89  return bigint_comba_mul6(z, x, y);
90  case 8:
91  return bigint_comba_mul8(z, x, y);
92  case 9:
93  return bigint_comba_mul9(z, x, y);
94  case 16:
95  return bigint_comba_mul16(z, x, y);
96  case 24:
97  return bigint_comba_mul24(z, x, y);
98  default:
99  return basecase_mul(z, 2*N, x, N, y, N);
100  }
101  }
102 
103  const size_t N2 = N / 2;
104 
105  const word* x0 = x;
106  const word* x1 = x + N2;
107  const word* y0 = y;
108  const word* y1 = y + N2;
109  word* z0 = z;
110  word* z1 = z + N;
111 
112  word* ws0 = workspace;
113  word* ws1 = workspace + N;
114 
115  clear_mem(workspace, 2*N);
116 
117  /*
118  * If either of cmp0 or cmp1 is zero then z0 or z1 resp is zero here,
119  * resulting in a no-op - z0*z1 will be equal to zero so we don't need to do
120  * anything, clear_mem above already set the correct result.
121  *
122  * However we ignore the result of the comparisons and always perform the
123  * subtractions and recursively multiply to avoid the timing channel.
124  */
125 
126  // First compute (X_lo - X_hi)*(Y_hi - Y_lo)
127  const auto cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace);
128  const auto cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace);
129  const auto neg_mask = ~(cmp0 ^ cmp1);
130 
131  karatsuba_mul(ws0, z0, z1, N2, ws1);
132 
133  // Compute X_lo * Y_lo
134  karatsuba_mul(z0, x0, y0, N2, ws1);
135 
136  // Compute X_hi * Y_hi
137  karatsuba_mul(z1, x1, y1, N2, ws1);
138 
139  const word ws_carry = bigint_add3_nc(ws1, z0, N, z1, N);
140  word z_carry = bigint_add2_nc(z + N2, N, ws1, N);
141 
142  z_carry += bigint_add2_nc(z + N + N2, N2, &ws_carry, 1);
143  bigint_add2_nc(z + N + N2, N2, &z_carry, 1);
144 
145  clear_mem(workspace + N, N2);
146 
147  bigint_cnd_add_or_sub(neg_mask, z + N2, workspace, 2*N-N2);
148  }
149 
150 /*
151 * Karatsuba Squaring Operation
152 */
153 void karatsuba_sqr(word z[], const word x[], size_t N, word workspace[])
154  {
155  if(N < KARATSUBA_SQUARE_THRESHOLD || N % 2)
156  {
157  switch(N)
158  {
159  case 6:
160  return bigint_comba_sqr6(z, x);
161  case 8:
162  return bigint_comba_sqr8(z, x);
163  case 9:
164  return bigint_comba_sqr9(z, x);
165  case 16:
166  return bigint_comba_sqr16(z, x);
167  case 24:
168  return bigint_comba_sqr24(z, x);
169  default:
170  return basecase_sqr(z, 2*N, x, N);
171  }
172  }
173 
174  const size_t N2 = N / 2;
175 
176  const word* x0 = x;
177  const word* x1 = x + N2;
178  word* z0 = z;
179  word* z1 = z + N;
180 
181  word* ws0 = workspace;
182  word* ws1 = workspace + N;
183 
184  clear_mem(workspace, 2*N);
185 
186  // See comment in karatsuba_mul
187  bigint_sub_abs(z0, x0, x1, N2, workspace);
188  karatsuba_sqr(ws0, z0, N2, ws1);
189 
190  karatsuba_sqr(z0, x0, N2, ws1);
191  karatsuba_sqr(z1, x1, N2, ws1);
192 
193  const word ws_carry = bigint_add3_nc(ws1, z0, N, z1, N);
194  word z_carry = bigint_add2_nc(z + N2, N, ws1, N);
195 
196  z_carry += bigint_add2_nc(z + N + N2, N2, &ws_carry, 1);
197  bigint_add2_nc(z + N + N2, N2, &z_carry, 1);
198 
199  /*
200  * This is only actually required if cmp (result of bigint_sub_abs) is != 0,
201  * however if cmp==0 then ws0[0:N] == 0 and avoiding the jump hides a
202  * timing channel.
203  */
204  bigint_sub2(z + N2, 2*N-N2, ws0, N);
205  }
206 
/*
* Pick an even operand length N for karatsuba_mul, or return 0 if none fits.
*
* N must lie in [max(x_sw, y_sw), min(x_size, y_size)] so that it covers all
* significant words while staying inside both input buffers, and 2*N must fit
* in z_size (this last check is skipped when the range is a single value;
* bigint_mul re-checks it). When the first even candidate is 2 mod 4, the
* next even size is preferred if it still fits, so that N/2 is even and the
* recursive calls are not forced into the base case by an odd length.
*/
size_t karatsuba_size(size_t z_size,
                      size_t x_size, size_t x_sw,
                      size_t y_size, size_t y_sw)
   {
   const bool sw_fit = (x_sw <= x_size) && (x_sw <= y_size) &&
                       (y_sw <= x_size) && (y_sw <= y_size);
   if(!sw_fit)
      return 0;

   const bool x_full_and_odd = (x_size == x_sw) && (x_size % 2 != 0);
   const bool y_full_and_odd = (y_size == y_sw) && (y_size % 2 != 0);
   if(x_full_and_odd || y_full_and_odd)
      return 0;

   const size_t lo = (x_sw > y_sw) ? x_sw : y_sw;
   const size_t hi = (x_size < y_size) ? x_size : y_size;

   if(lo == hi)
      return (lo % 2 == 0) ? lo : 0;

   if(lo > hi)
      return 0;

   // lo < hi, so rounding lo up to even stays inside [lo, hi]
   const size_t n = (lo % 2 == 0) ? lo : lo + 1;

   if(2*n > z_size)
      return 0;

   const size_t bumped = n + 2;
   const bool bumped_fits = (bumped <= x_size) && (bumped <= y_size) &&
                            (2*bumped <= z_size);

   if(n % 4 == 2 && bumped_fits)
      return bumped;
   return n;
   }
250 
/*
* Pick an even operand length N for karatsuba_sqr, or return 0 if none fits.
*
* N must lie in [x_sw, x_size] and 2*N must fit in z_size (not checked when
* x_sw == x_size; bigint_sqr re-checks it). As for multiplication, a first
* candidate that is 2 mod 4 is bumped by 2 when that still fits, so that
* N/2 is even.
*/
size_t karatsuba_size(size_t z_size, size_t x_size, size_t x_sw)
   {
   if(x_sw == x_size)
      return (x_sw % 2 == 0) ? x_sw : 0;

   if(x_sw > x_size)
      return 0;

   // x_sw < x_size, so rounding x_sw up to even stays inside [x_sw, x_size]
   const size_t n = (x_sw % 2 == 0) ? x_sw : x_sw + 1;

   if(2*n > z_size)
      return 0;

   const size_t bumped = n + 2;
   if(n % 4 == 2 && bumped <= x_size && 2*bumped <= z_size)
      return bumped;
   return n;
   }
278 
/*
* True when a fixed SZ-word comba multiply can be used: both operands have at
* most SZ significant words, both buffers are at least SZ words long (the
* comba routine reads exactly SZ words of each), and z has room for 2*SZ.
*/
template<size_t SZ>
inline bool sized_for_comba_mul(size_t x_sw, size_t x_size,
                                size_t y_sw, size_t y_size,
                                size_t z_size)
   {
   const bool x_ok = (x_sw <= SZ) && (SZ <= x_size);
   const bool y_ok = (y_sw <= SZ) && (SZ <= y_size);
   const bool z_ok = (2*SZ <= z_size);
   return x_ok && y_ok && z_ok;
   }
288 
/*
* True when a fixed SZ-word comba squaring can be used: x has at most SZ
* significant words in a buffer of at least SZ words, and z has room for 2*SZ.
*/
template<size_t SZ>
inline bool sized_for_comba_sqr(size_t x_sw, size_t x_size,
                                size_t z_size)
   {
   if(x_sw > SZ || x_size < SZ)
      return false;
   return z_size >= 2*SZ;
   }
295 
296 }
297 
298 void bigint_mul(word z[], size_t z_size,
299  const word x[], size_t x_size, size_t x_sw,
300  const word y[], size_t y_size, size_t y_sw,
301  word workspace[], size_t ws_size)
302  {
303  clear_mem(z, z_size);
304 
305  if(x_sw == 1)
306  {
307  bigint_linmul3(z, y, y_sw, x[0]);
308  }
309  else if(y_sw == 1)
310  {
311  bigint_linmul3(z, x, x_sw, y[0]);
312  }
313  else if(sized_for_comba_mul<4>(x_sw, x_size, y_sw, y_size, z_size))
314  {
315  bigint_comba_mul4(z, x, y);
316  }
317  else if(sized_for_comba_mul<6>(x_sw, x_size, y_sw, y_size, z_size))
318  {
319  bigint_comba_mul6(z, x, y);
320  }
321  else if(sized_for_comba_mul<8>(x_sw, x_size, y_sw, y_size, z_size))
322  {
323  bigint_comba_mul8(z, x, y);
324  }
325  else if(sized_for_comba_mul<9>(x_sw, x_size, y_sw, y_size, z_size))
326  {
327  bigint_comba_mul9(z, x, y);
328  }
329  else if(sized_for_comba_mul<16>(x_sw, x_size, y_sw, y_size, z_size))
330  {
331  bigint_comba_mul16(z, x, y);
332  }
333  else if(sized_for_comba_mul<24>(x_sw, x_size, y_sw, y_size, z_size))
334  {
335  bigint_comba_mul24(z, x, y);
336  }
337  else if(x_sw < KARATSUBA_MULTIPLY_THRESHOLD ||
338  y_sw < KARATSUBA_MULTIPLY_THRESHOLD ||
339  !workspace)
340  {
341  basecase_mul(z, z_size, x, x_sw, y, y_sw);
342  }
343  else
344  {
345  const size_t N = karatsuba_size(z_size, x_size, x_sw, y_size, y_sw);
346 
347  if(N && z_size >= 2*N && ws_size >= 2*N)
348  karatsuba_mul(z, x, y, N, workspace);
349  else
350  basecase_mul(z, z_size, x, x_sw, y, y_sw);
351  }
352  }
353 
354 /*
355 * Squaring Algorithm Dispatcher
356 */
357 void bigint_sqr(word z[], size_t z_size,
358  const word x[], size_t x_size, size_t x_sw,
359  word workspace[], size_t ws_size)
360  {
361  clear_mem(z, z_size);
362 
363  BOTAN_ASSERT(z_size/2 >= x_sw, "Output size is sufficient");
364 
365  if(x_sw == 1)
366  {
367  bigint_linmul3(z, x, x_sw, x[0]);
368  }
369  else if(sized_for_comba_sqr<4>(x_sw, x_size, z_size))
370  {
371  bigint_comba_sqr4(z, x);
372  }
373  else if(sized_for_comba_sqr<6>(x_sw, x_size, z_size))
374  {
375  bigint_comba_sqr6(z, x);
376  }
377  else if(sized_for_comba_sqr<8>(x_sw, x_size, z_size))
378  {
379  bigint_comba_sqr8(z, x);
380  }
381  else if(sized_for_comba_sqr<9>(x_sw, x_size, z_size))
382  {
383  bigint_comba_sqr9(z, x);
384  }
385  else if(sized_for_comba_sqr<16>(x_sw, x_size, z_size))
386  {
387  bigint_comba_sqr16(z, x);
388  }
389  else if(sized_for_comba_sqr<24>(x_sw, x_size, z_size))
390  {
391  bigint_comba_sqr24(z, x);
392  }
393  else if(x_size < KARATSUBA_SQUARE_THRESHOLD || !workspace)
394  {
395  basecase_sqr(z, z_size, x, x_sw);
396  }
397  else
398  {
399  const size_t N = karatsuba_size(z_size, x_size, x_sw);
400 
401  if(N && z_size >= 2*N && ws_size >= 2*N)
402  karatsuba_sqr(z, x, N, workspace);
403  else
404  basecase_sqr(z, z_size, x, x_sw);
405  }
406  }
407 
408 }
void carry(int64_t &h0, int64_t &h1)
void clear_mem(T *ptr, size_t n)
Definition: mem_ops.h:111
word bigint_sub2(word x[], size_t x_size, const word y[], size_t y_size)
Definition: mp_core.h:302
word bigint_add3_nc(word z[], const word x[], size_t x_size, const word y[], size_t y_size)
Definition: mp_core.h:252
word word_madd3(word a, word b, word c, word *d)
Definition: mp_madd.h:94
void bigint_comba_mul4(word z[8], const word x[4], const word y[4])
Definition: mp_comba.cpp:50
void bigint_comba_mul9(word z[18], const word x[9], const word y[9])
Definition: mp_comba.cpp:474
void bigint_sqr(word z[], size_t z_size, const word x[], size_t x_size, size_t x_sw, word workspace[], size_t ws_size)
Definition: mp_karat.cpp:357
void bigint_comba_mul24(word z[48], const word x[24], const word y[24])
Definition: mp_comba.cpp:1535
#define BOTAN_ASSERT(expr, assertion_made)
Definition: assert.h:55
void bigint_linmul3(word z[], const word x[], size_t x_size, word y)
Definition: mp_core.h:504
void bigint_comba_sqr16(word z[32], const word x[16])
Definition: mp_comba.cpp:598
word word8_madd3(word z[8], const word x[8], word y, word carry)
Definition: mp_asmi.h:681
CT::Mask< word > bigint_sub_abs(word z[], const word x[], const word y[], size_t N, word ws[])
Definition: mp_core.h:379
void bigint_comba_sqr24(word z[48], const word x[24])
Definition: mp_comba.cpp:1132
word bigint_add2_nc(word x[], size_t x_size, const word y[], size_t y_size)
Definition: mp_core.h:229
void bigint_cnd_add_or_sub(CT::Mask< word > mask, word x[], const word y[], size_t size)
Definition: mp_core.h:141
Definition: alg_id.cpp:13
void bigint_comba_sqr9(word z[18], const word x[9])
Definition: mp_comba.cpp:386
void bigint_mul(word z[], size_t z_size, const word x[], size_t x_size, size_t x_sw, const word y[], size_t y_size, size_t y_sw, word workspace[], size_t ws_size)
Definition: mp_karat.cpp:298
void bigint_comba_mul8(word z[16], const word x[8], const word y[8])
Definition: mp_comba.cpp:283
void bigint_comba_sqr8(word z[16], const word x[8])
Definition: mp_comba.cpp:208
void bigint_comba_mul16(word z[32], const word x[16], const word y[16])
Definition: mp_comba.cpp:805
void bigint_comba_mul6(word z[12], const word x[6], const word y[6])
Definition: mp_comba.cpp:141
void bigint_comba_sqr4(word z[8], const word x[4])
Definition: mp_comba.cpp:17
void bigint_comba_sqr6(word z[12], const word x[6])
Definition: mp_comba.cpp:89