Cryptography
RSA Wiener attack code
chltjdbs
2024. 2. 17. 12:44
wiener attack 이란 단순히 말해 e 값이 클때 사용할수 있는 공격이다.
대게 확률적으로 e값이 증폭적으로 크다면 d값이 작을 확률이 매우 높기때문에
wiener 공격을 하면 d를 알아낼수있어 RSA 취약점인 개인키가 공개되기 때문에 이를 이용하여 flag를 알아낼수있다
RSA에서 텍스트파일로 저장된 Public key.pem 파일을 읽어와서 키 파일을 얻어오는 코드와
wiener 공격을 한 파일에 담아 정리했다
https://github.com/pablocelayes/rsa-wiener-attack/blob/master/RSAwienerHacker.py
E,n값과 public key 텍스트 파일을 입력해 긁어오는 방식
키를 직접 추가 하고 싶다면 키를 가지고 오는 코드를 삭제후 본인이 입력하면 제대로 동작한다.
from Crypto.Util.number import*
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
from base64 import b64encode
# Get public key from a file
pubkey_file = open("public_key.pem", 'r')
pubkey = pubkey_file.read()
pubkey_file.close()
flag_file = open("flag.txt", 'r')
flag = int(flag_file.read(), 16)
flag_file.close()
msg = "plain text"
# Create key using public key
keyPub = RSA.importKey(pubkey)
cipher = Cipher_PKCS1_v1_5.new(keyPub)
cipher_text = cipher.encrypt(msg.encode())
emsg = b64encode(cipher_text)
encryptedText = emsg.decode('utf-8')
n = keyPub.n
e = keyPub.e
def convergents_from_contfrac(frac):
convs = [];
for i in range(len(frac)):
convs.append(contfrac_to_rational(frac[0:i]))
return convs
def contfrac_to_rational (frac):
if len(frac) == 0:
return (0,1)
num = frac[-1]
denom = 1
for _ in range(-2,-len(frac)-1,-1):
num, denom = frac[_]*num+denom, num
return (num,denom)
def is_perfect_square(n):
h = n & 0xF;
if h > 9:
return -1
if ( h != 2 and h != 3 and h != 5 and h != 6 and h != 7 and h != 8 ):
# take square root if you must
t = isqrt(n)
if t*t == n:
return t
else:
return -1
return -1
def test_is_perfect_square():
print("Testing is_perfect_square")
testsuit = [4, 0, 15, 25, 18, 901, 1000, 1024]
for n in testsuit:
print("Is ", n, " a perfect square?")
if is_perfect_square(n)!= -1:
print("Yes!")
else:
print("Nope")
def isqrt(n):
if n < 0:
raise ValueError('square root not defined for negative numbers')
if n == 0:
return 0
a, b = divmod(bitlength(n), 2)
x = 2**(a+b)
while True:
y = (x + n//x)//2
if y >= x:
return x
x = y
def bitlength(x):
assert x >= 0
n = 0
while x > 0:
n = n+1
x = x>>1
return n
def hack_RSA(e,n):
# rational_to_contfrac ref
x, y = e, n
a = x//y
pquotients = [a]
while a * y != x:
x,y = y,x-a*y
a = x//y
pquotients.append(a)
frac = pquotients
# rational_to_contfrac ref end
convergents = convergents_from_contfrac(frac)
for (k,d) in convergents:
#check if d is actually the key
if k!=0 and (e*d-1)%k == 0:
phi = (e*d-1)//k
s = n - phi + 1
discr = s*s - 4*n
if(discr>=0):
t = is_perfect_square(discr)
if t!=-1 and (s+t)%2==0:
print("Hacked!")
return d
d = hack_RSA(e,n)
pt = pow(flag,d,n)
print(long_to_bytes(pt))