Attacking a Discrete Knapsack Public Key Cryptosystem using LLL lattice reduction
In this example we use our library to construct a Discrete Knapsack Public Key Cryptosystem, then attack it with LLL lattice reduction and recover an encrypted message without knowing the private key.
Discrete Knapsack Public Key Cryptosystem
The security of this cryptosystem rests on the Discrete Knapsack Packing Problem (also known as the subset-sum problem), which is NP-hard.
In theory this should be a very strong guarantee. In practice, however, building a trapdoor requires a special kind of sequence (a superincreasing sequence), and this hidden structure lets us find and exploit weaknesses of the cryptosystem without solving the underlying general NP-hard problem.
Discrete Knapsack Packing Problem
Given a list of positive integers $(M_1, M_2, \ldots, M_n)$ and another integer $S$, find a subset of the elements in the list that sums to $S$.
In general this problem is very hard to solve. However, if the list is a so-called superincreasing sequence, a straightforward greedy algorithm finds a solution.
We say that a list of positive integers $(r_1, r_2, \ldots, r_n)$ is superincreasing if
$$
r_{i+1} \geq 2r_i \quad \text{for all} \; 1 \leq i \leq n - 1
$$
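The doubling condition is what makes the greedy approach work: every element is larger than the sum of all the elements before it. A short derivation, using $r_j \leq r_i / 2^{\,i-j}$ for $j \leq i$:

$$
\sum_{j=1}^{i} r_j \leq \sum_{j=1}^{i} \frac{r_i}{2^{\,i-j}} = r_i \left(2 - \frac{1}{2^{\,i-1}}\right) < 2r_i \leq r_{i+1}
$$

So if the remaining target is at least $r_{i+1}$, the element $r_{i+1}$ must be part of the subset, because all smaller elements together cannot reach it.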
Suppose we are asked to solve the knapsack problem $(M, S)$, where $M$ is an integer sequence and $S$ is the sum we have to obtain. If $M$ is a superincreasing sequence, we can find a solution with the following greedy algorithm.
$$
\begin{aligned}
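\text{1.} & \text{ scan the elements of} \; M \; \text{from the largest to the smallest}. \newline
\text{2.} & \text{ if the current element is at most the remaining} \; S, \; \text{include it and subtract it from} \; S. \newline
\text{3.} & \text{ after the scan, the problem is solved if} \; S = 0; \; \text{otherwise it has no solution}.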
\end{aligned}
$$
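The greedy algorithm above can be sketched in Python as follows. The function name `solve_superincreasing` is our own (it is not part of `lbpqc`), and the example uses the secret sequence from the key-creation step further down.

```python
def solve_superincreasing(r, s):
    """Greedy subset-sum solver for a superincreasing sequence r.

    Returns a 0/1 list x with sum(x_i * r_i) == s, or None if no subset works.
    """
    x = [0] * len(r)
    # Walk from the largest element down to the smallest.
    for i in reversed(range(len(r))):
        if r[i] <= s:
            # r[i] exceeds the sum of all smaller elements, so any solution must use it.
            x[i] = 1
            s -= r[i]
    # Every element has been considered; a nonzero remainder means no solution exists.
    return x if s == 0 else None


print(solve_superincreasing([3, 11, 24, 50, 115], 142))  # [1, 0, 1, 0, 1]
print(solve_superincreasing([3, 11, 24, 50, 115], 4))    # None
```

Each element is examined once, so the running time is linear in $n$.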
The idea behind Discrete Knapsack Cryptosystems is quite simple. If we construct a superincreasing sequence and then disguise it from the public eye, an outsider appears to face a general knapsack problem, which seems to require some kind of exponential-time algorithm. We, however, can still find the solution easily by running the greedy algorithm on the undisguised sequence.
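To make the "exponential" part concrete, here is a sketch of the naive approach open to someone who only sees a disguised sequence: try all $2^n$ subsets. The function name `brute_force_knapsack` is hypothetical, and the example values are the public key and ciphertext computed later in this notebook.

```python
import itertools


def brute_force_knapsack(m, s):
    """Return every 0/1 vector x with sum(x_i * m_i) == s by trying all 2**n subsets."""
    solutions = []
    for x in itertools.product((0, 1), repeat=len(m)):
        if sum(xi * mi for xi, mi in zip(x, m)) == s:
            solutions.append(list(x))
    return solutions


# 2**5 = 32 candidates here; every additional element doubles the work.
print(brute_force_knapsack([89, 243, 212, 150, 245], 546))  # [[1, 0, 1, 0, 1]]
```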
Key creation
Let's choose a superincreasing sequence $\boldsymbol{r} = (r_1, \ldots, r_n)$ and two large secret integers $A$ and $B$ satisfying
$$
B > 2r_n \quad \text{and} \quad \gcd(A,B) = 1.
$$
We can then create a new sequence $\boldsymbol{M}$, which is in general not superincreasing, with the following method
$$
M_i \equiv Ar_i \mod B \quad \text{for all} \; 1 \leq i \leq n
$$
The triple $(\boldsymbol{r}, A, B)$ is our private key, and $\boldsymbol{M}$ is our public key.
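```python
import math
import itertools


def is_superincreasing(xs):
    """Check the doubling condition x_{i+1} >= 2 * x_i for consecutive elements."""
    return all(j >= 2 * i for i, j in itertools.pairwise(xs))  # pairwise needs Python 3.10+


# Private key: a superincreasing sequence r and secret integers A, B.
r = [3, 11, 24, 50, 115]
A = 113
B = 250

assert is_superincreasing(r)
assert B > 2 * r[-1]
assert math.gcd(A, B) == 1

# Public key: disguise r by modular multiplication with A.
M = [(A * ri) % B for ri in r]
assert not is_superincreasing(M)

private_key = (r, A, B)
public_key = M

print("private key:", private_key)
print("public key:", public_key)
```

```
private key: ([3, 11, 24, 50, 115], 113, 250)
public key: [89, 243, 212, 150, 245]
```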
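The key above uses fixed toy numbers. As a hedged sketch, keys of any length could be generated randomly along these lines. `generate_keys` and the ranges it draws from are our own illustrative choices (far too small for real use), and the results are stored under new names so the notebook's `private_key` and `public_key` stay untouched.

```python
import math
import random


def generate_keys(n, seed=None):
    """Hypothetical toy key generator: returns ((r, A, B), M)."""
    rng = random.Random(seed)
    # Superincreasing sequence: each element strictly more than twice the previous one.
    r = [rng.randint(1, 10)]
    for _ in range(n - 1):
        r.append(2 * r[-1] + rng.randint(1, 10))
    # Modulus B > 2 * r_n and a multiplier A coprime to B.
    B = 2 * r[-1] + rng.randint(1, 100)
    while True:
        A = rng.randrange(2, B)
        if math.gcd(A, B) == 1:
            break
    M = [(A * ri) % B for ri in r]
    return (r, A, B), M


demo_private_key, demo_public_key = generate_keys(8, seed=1)
print("demo private key:", demo_private_key)
print("demo public key:", demo_public_key)
```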
Encryption
Our plaintext is a binary vector $\boldsymbol{x} = (x_1, \ldots, x_n)$ with $x_i \in \{0, 1\}$.
Its $i$-th coefficient indicates whether the $i$-th element of the public key $\boldsymbol{M}$ is counted into the knapsack sum.
The resulting sum $S = \sum_{i=1}^{n} x_i M_i$ is our ciphertext.
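```python
# Plaintext: a binary vector selecting elements of the public key.
x = [1, 0, 1, 0, 1]

# Ciphertext: the knapsack sum of the selected public-key elements.
S = sum(xi * Mi for xi, Mi in zip(x, M))

plaintext = x
ciphertext = S

print("plaintext:", plaintext)
print("ciphertext:", ciphertext)
```

```
plaintext: [1, 0, 1, 0, 1]
ciphertext: 546
```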
Decryption
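Upon receiving the sum $S$, we transform it into the "space" of our superincreasing sequence:
$$ S' \equiv A^{-1} S \mod{B} $$
Taking $S'$ in the range $0 \leq S' < B$, we have $S' \equiv \sum_{i=1}^{n} x_i A^{-1} M_i \equiv \sum_{i=1}^{n} x_i r_i \pmod{B}$, and since $\sum_{i=1}^{n} x_i r_i \leq \sum_{i=1}^{n} r_i < 2r_n < B$, the value $S'$ equals the knapsack sum over the secret sequence $\boldsymbol{r}$ exactly. This is why we required $B > 2r_n$. We then use the greedy algorithm on our secret superincreasing sequence to find the binary vector $\boldsymbol{y}$, which is the decrypted message.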
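For the toy key above, with $A = 113$, $B = 250$ and ciphertext $S = 546$, the arithmetic works out as follows (computed by hand):

$$
\begin{aligned}
A^{-1} &\equiv 177 \pmod{250} \quad \text{since} \; 113 \cdot 177 = 20001 = 80 \cdot 250 + 1, \newline
S' &\equiv 177 \cdot 546 = 96642 \equiv 142 \pmod{250}, \newline
142 &= 3 + 24 + 115 = r_1 + r_3 + r_5,
\end{aligned}
$$

so the greedy algorithm recovers $\boldsymbol{y} = (1, 0, 1, 0, 1)$.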
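```python
from lbpqc.primitives.integer.integer_ring import modinv

r, A, B = private_key
M = public_key
S = ciphertext

# Map the ciphertext back into the space of the superincreasing sequence.
Sprim = (modinv(A, B) * S) % B

# Greedy algorithm: walk r from its largest element down to its smallest.
y = [0 for _ in r]
for i, ri in enumerate(r[::-1]):
    if ri <= Sprim:
        y[i] = 1
        Sprim -= ri
# y was filled in reversed order, so flip it back.
y = y[::-1]
assert Sprim == 0  # a remainder would mean S is not a valid knapsack sum

print("decrypted message:", y)
assert y == x
```

```
decrypted message: [1, 0, 1, 0, 1]
```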
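For readers without the `lbpqc` package, the same decryption can be sketched with the standard library only: since Python 3.8, `pow(A, -1, B)` computes the modular inverse. The function name `knapsack_decrypt` is our own and not part of the library.

```python
def knapsack_decrypt(private_key, s):
    """Decrypt a knapsack ciphertext s with private_key = (r, A, B)."""
    r, a, b = private_key
    s_prime = (pow(a, -1, b) * s) % b  # modular inverse of A (Python 3.8+)
    y = [0] * len(r)
    # Greedy walk over the secret superincreasing sequence, largest element first.
    for i in reversed(range(len(r))):
        if r[i] <= s_prime:
            y[i] = 1
            s_prime -= r[i]
    if s_prime != 0:
        raise ValueError("ciphertext is not a valid knapsack sum")
    return y


print(knapsack_decrypt(([3, 11, 24, 50, 115], 113, 250), 546))  # [1, 0, 1, 0, 1]
```

Raising an error on a nonzero remainder makes malformed ciphertexts visible instead of silently returning a wrong bit vector.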
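Attacking the cryptosystem
An attacker who sees only the public key $\boldsymbol{M}$ and the ciphertext $S$ can build the lattice spanned by the rows $\boldsymbol{v}_1, \ldots, \boldsymbol{v}_{n+1}$ of the $(n+1) \times (n+1)$ matrix
$$
\begin{pmatrix}
2 & 0 & \cdots & 0 & M_1 \newline
0 & 2 & \cdots & 0 & M_2 \newline
\vdots & & \ddots & & \vdots \newline
0 & 0 & \cdots & 2 & M_n \newline
1 & 1 & \cdots & 1 & S
\end{pmatrix}
$$
The lattice vector
$$
\boldsymbol{t} = \sum_{i=1}^{n} x_i \boldsymbol{v}_i - \boldsymbol{v}_{n+1} = (2x_1 - 1, \ldots, 2x_n - 1, 0)
$$
has every coordinate in $\{-1, 0, 1\}$, so its length is only $\sqrt{n}$, much shorter than the basis vectors, whose last coordinates are as large as the public-key entries. LLL reduction is therefore likely to place $\pm\boldsymbol{t}$ in the reduced basis, and its first $n$ coordinates reveal the plaintext $\boldsymbol{x}$.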
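Before running LLL, the claim about $\boldsymbol{t}$ can be checked numerically. The sketch below (the helper name `knapsack_basis` is hypothetical) builds the same basis as the attack code that follows, using the public key, ciphertext and plaintext from the earlier cells, and confirms that the coefficient vector $(x_1, \ldots, x_n, -1)$ yields $(2x_1 - 1, \ldots, 2x_n - 1, 0)$.

```python
import numpy as np


def knapsack_basis(m, s):
    """Rows 2*e_i + M_i*e_{n+1} for i = 1..n, plus the row (1, ..., 1, S)."""
    n = len(m)
    basis = 2 * np.identity(n + 1, dtype=int)
    basis[-1] = 1          # last row: all ones ...
    basis[:-1, -1] = m     # last column: the public key ...
    basis[-1, -1] = s      # ... and the ciphertext in the corner
    return basis


t = np.array([1, 0, 1, 0, 1, -1]) @ knapsack_basis([89, 243, 212, 150, 245], 546)
print(t.tolist())                  # [1, -1, 1, -1, 1, 0]
print(float(np.linalg.norm(t)))    # about 2.236, i.e. sqrt(5)
```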
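```python
import numpy as np
from lbpqc.primitives.lattice import reductions

n = len(M)

# Lattice basis, one vector per row: 2*e_i + M_i*e_{n+1} and (1, ..., 1, S).
# Named `basis` so that it does not overwrite the private-key multiplier A.
basis = np.identity(n + 1, dtype=int) * 2
basis[-1] = 1
basis[:-1, -1] = M
basis[-1, -1] = S

reduced = reductions.LLL(basis).astype(int)
w = reduced[0]
print("short vector in LLL reduced basis:")
print("w = ", w)

# Express w in the original basis: coeffs @ basis == w.
# Round before casting, because astype(int) truncates values like 0.9999999 to 0.
coeffs = np.rint(w @ np.linalg.inv(basis)).astype(int)
x_rec, k = coeffs[:-1], coeffs[-1]
print(x_rec)
# The last coordinate of w is 0, so M . x_rec + k * S = 0.
print(np.dot(np.array(M), x_rec))
print(-k * S)
```

```
short vector in LLL reduced basis:
w =  [-1  1 -1  1 -1  0]
[-1  0 -1  0 -1]
-546
-546
```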
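The cell above stops at the linear relation between $\boldsymbol{M}$ and $S$. A minimal sketch of the last step, assuming the short vector has the form $\pm\boldsymbol{t}$ with $\boldsymbol{t} = (2x_1 - 1, \ldots, 2x_n - 1, 0)$: try both signs, map each coordinate back to a bit with $x_i = (t_i + 1)/2$, and keep the candidate whose knapsack sum reproduces the ciphertext. LLL is not guaranteed to return $\pm\boldsymbol{t}$, so the helper returns `None` otherwise. The name `plaintext_from_short_vector` is hypothetical; the inputs are the short vector, public key and ciphertext printed above.

```python
import numpy as np


def plaintext_from_short_vector(w, m, s):
    """Recover x from a short vector w = +/-(2x_1 - 1, ..., 2x_n - 1, 0), or None."""
    w = np.asarray(w)
    if w[-1] != 0 or not np.all(np.abs(w[:-1]) == 1):
        return None                      # not of the expected +/-1 shape
    for sign in (1, -1):
        x = (sign * w[:-1] + 1) // 2     # +1 -> 1, -1 -> 0
        if int(np.dot(m, x)) == s:       # candidate must reproduce the ciphertext
            return x.tolist()
    return None


print(plaintext_from_short_vector([-1, 1, -1, 1, -1, 0], [89, 243, 212, 150, 245], 546))
# [1, 0, 1, 0, 1]
```

Checking both signs is necessary because $-\boldsymbol{t}$ is just as short as $\boldsymbol{t}$; here the wrong sign gives the complementary subset, whose sum is $393 \neq 546$.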