##################################################
# ent.py -- Element Number Theory
# (c) William Stein, 2004
##################################################

from random import randrange
from math import log, sqrt

# The `imp` module was removed in Python 3.12; prefer importlib and only
# fall back to `imp` on interpreters where importlib has no `reload`.
try:
    from importlib import reload
except ImportError:
    from imp import reload


##################################################
## Greatest Common Divisors
##################################################

def gcd(a, b):                                        # (1)
    """
    Returns the greatest common divisor of a and b.
    Input:
        a -- an integer
        b -- an integer
    Output:
        an integer, the gcd of a and b (always non-negative;
        gcd(0, 0) is 0)
    """
    a, b = abs(a), abs(b)
    # Euclid's algorithm.
    while b != 0:
        a, b = b, a % b
    return a


def sign(n):
    """
    Returns -1 if n<0 and +1 otherwise (so sign(0) is +1).
    """
    if n < 0:
        return -1
    return 1


def xgcd(a, b):
    """
    Returns g, x, y such that g = x*a + y*b = gcd(a,b).
    Input:
        a -- an integer
        b -- an integer
    Output:
        g -- an integer, the gcd of a and b
        x -- an integer
        y -- an integer
    """
    # Work with absolute values and restore the signs at the end.
    a, x_sign = abs(a), sign(a)
    b, y_sign = abs(b), sign(b)
    if a == 0:
        return (b, 0, y_sign)
    if b == 0:
        return (a, x_sign, 0)
    # Invariants (a0, b0 are the absolute values of the inputs):
    #     a = x*a0 + y*b0   and   b = r*a0 + s*b0
    x, y, r, s = 1, 0, 0, 1
    while b != 0:
        q = a // b
        (a, x, y, b, r, s) = (b, r, s, a - q*b, x - q*r, y - q*s)
    return (a, x*x_sign, y*y_sign)


def inversemod(a, n):
    """
    Returns the inverse of a modulo n, normalized to
    lie between 0 and n-1.  If a is not coprime to n,
    raise an exception (this will be useful later for
    the elliptic curve factorization method).
    Input:
        a -- an integer coprime to n
        n -- a positive integer
    Output:
        an integer between 0 and n-1.
    Raises:
        ZeroDivisionError -- if gcd(a, n) != 1; the exception's
        args are the pair (a, n).
    """
    g, x, _ = xgcd(a, n)
    if g != 1:
        # Raise a real exception instance: raising a tuple is a
        # TypeError on Python 3.
        raise ZeroDivisionError(a, n)
    return x % n


##################################################
## Determining Whether a Number is Prime
##################################################

def miller_rabin(n, num_trials=10):
    """
    True if n is likely prime, and False if n
    is definitely not prime.  Increasing num_trials
    increases the probability of correctness.
    (One can prove that the probability that this
    function returns True when it should return
    False is at most (1/4)**num_trials.)
    Input:
        n -- an integer (its absolute value is tested)
        num_trials -- the number of trials with the
                      primality test.
    Output:
        bool -- True or False
    """
    if n < 0:
        n = -n
    if n in (2, 3):
        return True
    # 0, 1, 4 and every other even number are not prime.
    if n <= 4 or n % 2 == 0:
        return False
    m = n - 1
    k = 0
    while m % 2 == 0:
        k += 1
        m //= 2
    # Now n - 1 = (2**k) * m with m odd
    for _ in range(num_trials):
        a = randrange(2, n - 1)
        apow = pow(a, m, n)
        if apow in (1, n - 1):
            continue
        # Square up to k-1 times looking for -1.  Reaching 1 without
        # passing through -1 exposes a nontrivial square root of 1,
        # so n is composite in that case as well.
        for _ in range(k - 1):
            apow = pow(apow, 2, n)
            if apow == n - 1:
                break
        else:
            return False
    return True


def random_prime(num_digits, is_prime=miller_rabin):
    """
    Returns a random prime with num_digits digits.
    Input:
        num_digits -- a positive integer
        is_prime -- (optional argument)
                    a function of one argument n that
                    returns either True if n is (probably)
                    prime and False otherwise.
    Output:
        int -- an odd (probable) prime p with
               10**(num_digits-1) <= p < 10**num_digits
    """
    lower, upper = 10**(num_digits - 1), 10**num_digits
    while True:
        n = randrange(lower, upper)
        if n % 2 == 0:
            n += 1
        while n < upper and not is_prime(n):
            n += 2
        if n < upper:
            return n
        # The search ran past the largest num_digits-digit number;
        # draw a fresh starting point instead of returning a prime
        # with too many digits.


def random_nice_prime(num_digits, mod_four=None):
    """
    Returns a random prime with num_digits digits, optionally
    with a prescribed residue modulo 4.
    Input:
        num_digits -- a positive integer
        mod_four -- (optional argument)
                    an integer that will equal the residue
                    (either 1 or 3) of the returned prime
                    modulo 4.  If omitted, any odd prime
                    may be returned.
    Output:
        int -- a probable prime with num_digits digits,
               congruent to mod_four mod 4 when mod_four is given
    """
    # is_prime is a function of one argument n that returns
    # either True if n is (probably) prime and False otherwise.
    # I choose miller_rabin since it is defined in this file.
    is_prime = miller_rabin
    if not mod_four:
        shift = 2
    else:
        shift = 4
        mod_four %= 4
        assert mod_four in (1,3), "p must be odd"
    lower, upper = 10**(num_digits - 1), 10**num_digits
    while True:
        n = randrange(lower, upper)
        if not mod_four:
            n += (n - 1) % 2              # make n odd
        else:
            n += (mod_four - n) % 4       # move n into the wanted class
        while n < upper and not is_prime(n):
            n += shift
        if n < upper:
            return n
        # Overshot the num_digits-digit range: resample.


def next_prime(n, mod_four=None):
    """
    Returns the next prime larger than n.
    Input:
        n -- an integer
        mod_four -- (optional argument)
                    an integer that will equal the residue
                    (either 1 or 3) of the returned prime
                    modulo 4.
    Output:
        int -- the smallest probable prime larger than n
               (and congruent to mod_four mod 4 when given)
    """
    # is_prime is a function of one argument n that returns
    # either True if n is (probably) prime and False otherwise.
    # I choose miller_rabin since it is defined in this file.
    is_prime = miller_rabin
    if not mod_four:
        if n < 2:
            # 2 is the only even prime; the odd-stepping search
            # below would skip it.
            return 2
        shift = 2
        n += 1 + (n % 2)                  # next odd number above n
    else:
        shift = 4
        mod_four %= 4
        assert mod_four in (1,3), "p must be odd"
        # miller_rabin tests abs(n), so never start below zero.
        n = max(n, 0) + 1
        n += (mod_four - n) % 4
    while not is_prime(n):
        n += shift
    return n


##################################################
## Computing the Legendre Symbol
##################################################

def legendre(a, p):
    """
    Returns the Legendre symbol a over p, where
    p is an odd prime.
    Input:
        a -- an integer
        p -- an odd prime (primality not checked)
    Output:
        int: -1 if a is not a square mod p,
              0 if gcd(a,p) is not 1
              1 if a is a square mod p.
    """
    assert p%2 == 1, "p={0} isn't odd as it must be.".format(p)
    # Euler's criterion: a**((p-1)/2) is +1, -1 or 0 modulo p.
    b = pow(a, (p - 1)//2, p)
    if b == 1:
        return 1
    elif b == p - 1:
        return -1
    else:
        return 0


##################################################
## Computing square roots modulo a prime
##################################################

def sqrtmod(a, p):
    """
    Returns a square root of a modulo p.
    Input:
        a -- an integer that is a perfect
             square modulo p (this is checked)
        p -- a prime (primality not checked)
    Output:
        int -- a square root of a, as an integer
               between 0 and p-1.
    """
    a %= p
    if a == 0 or p == 2:
        return a
    assert legendre(a, p) == 1, "{0} is not a square mod {1}.".format(a,p)
    if p % 4 == 3:
        return pow(a, (p + 1)//4, p)

    # Otherwise work in R = (Z/pZ)[x]/(x^2 - a); the pair (c, d)
    # represents c + d*x.
    def mul(x, y):       # multiplication in R
        return ((x[0]*y[0] + a*y[1]*x[1]) % p,
                (x[0]*y[1] + x[1]*y[0]) % p)

    def power(x, n):     # exponentiation in R
        ans = (1, 0)
        xpow = x
        while n != 0:
            if n % 2 != 0:
                ans = mul(ans, xpow)
            xpow = mul(xpow, xpow)
            n //= 2
        return ans

    while True:
        z = randrange(1, p)
        u, v = power((1, z), (p - 1)//2)
        if v != 0:
            vinv = inversemod(v, p)
            # Try the three candidates (c - u)/v for c in {0, 1, -1}.
            for x in [-u*vinv, (1 - u)*vinv, (-1 - u)*vinv]:
                if (x*x) % p == a:
                    return x % p
            assert False, "Bug in sqrtmod."
##################################################
## Continued Fractions
##################################################

def convergents(v):
    """
    Returns the partial convergents of the continued
    fraction v.
    Input:
        v -- list of integers [a0, a1, a2, ..., am]
    Output:
        list -- list [(p0,q0), (p1,q1), ...]
                of pairs (pm,qm) such that the mth
                convergent of v is pm/qm.
    """
    # Seed with the conventional (p_{-2}, q_{-2}) and (p_{-1}, q_{-1}).
    w = [(0, 1), (1, 0)]
    for vn in v:
        (p_prev2, q_prev2), (p_prev1, q_prev1) = w[-2], w[-1]
        w.append((vn*p_prev1 + p_prev2, vn*q_prev1 + q_prev2))
    return w[2:]      # drop the two seed entries


def contfrac_rat(numer, denom):
    """
    Returns the continued fraction of the rational
    number numer/denom.
    Input:
        numer -- an integer
        denom -- a positive integer coprime to numer
    Output
        list -- the continued fraction [a0, a1, ..., am]
                of the rational number numer/denom.
        list -- the list [(p0,q0), (p1,q1), ...]
                of pairs (pm,qm) such that the mth
                convergent of continued fraction is pm/qm.
    """
    assert denom > 0, "denom={0} isn't positive as it must be.".format(denom)
    a, b = numer, denom
    v = []
    # The partial quotients are the quotients of Euclid's algorithm.
    while b != 0:
        v.append(a // b)
        a, b = b, a % b
    return v, convergents(v)


def contfrac_float(x):
    """
    Returns the continued fraction of the floating
    point number x, computed using the continued
    fraction procedure, and the sequence of partial
    convergents.
    Input:
        x -- a floating point number (decimal)
    Output:
        list -- the continued fraction [a0, a1, ...]
                obtained by applying the continued
                fraction procedure to x to the
                precision of this computer.
        list -- the list [(p0,q0), (p1,q1), ...]
                of pairs (pm,qm) such that the mth
                convergent of continued fraction is pm/qm.
    """
    v = []
    w = [(0, 1), (1, 0)]    # keep track of convergents
    start = x
    while True:
        # x // 1 is the floor of x, so negative inputs also get the
        # standard expansion with fractional part in [0, 1).
        a = int(x // 1)
        v.append(a)
        pn = a*w[-1][0] + w[-2][0]
        qn = a*w[-1][1] + w[-2][1]
        w.append((pn, qn))
        x -= a
        # Stop once the convergent reproduces the input exactly, or
        # when nothing is left to invert (avoids dividing by zero).
        if x == 0 or abs(start - float(pn)/float(qn)) == 0:
            return v, w[2:]
        x = 1/x


def _isqrt(n):
    """
    Returns the integer part of the square root of int(n),
    computed exactly with integer arithmetic.
    Raises ValueError if n is negative.
    """
    n = int(n)
    if n < 0:
        raise ValueError("math domain error")
    if n == 0:
        return 0
    # Newton's iteration started from a power of two >= sqrt(n).
    x = 1 << ((n.bit_length() + 1) // 2)
    while True:
        y = (x + n // x) // 2
        if y >= x:
            return x
        x = y


def sum_of_two_squares(p):
    """
    Uses continued fractions to efficiently compute
    a representation of the prime p as a sum of
    two squares.   The prime p must be 1 modulo 4.
    Input:
        p -- a prime congruent 1 modulo 4.
    Output:
        integers a, b such that p = a**2 + b**2
    """
    assert miller_rabin(p), "p={0} is not a prime as it must be.".format(p)
    assert p%4 == 1, "p={0} is not congruent to 1 modulo 4 as it must be.".format(p)
    r = sqrtmod(-1, p)
    v, w = contfrac_rat(r, p)
    # Exact integer square root: float sqrt is inexact (and can
    # overflow) for large p.
    n = _isqrt(p)
    for a, b in w:
        c = abs(r*b - p*a)
        if c <= n:
            return (b, c)
    assert False, "Bug in sum_of_two_squares."


def is_square(n):
    """
    Returns True if n is the square of an integer and False
    otherwise (in particular for negative n).
    """
    if n < 0:
        return False
    m = _isqrt(n)
    return n == m**2


def sum_of_two_squares_naive(n):
    """
    Uses brute force search to inefficiently compute
    a representation of the number n as a sum of two squares.
    Input:
        n -- a positive integer
    Output:
        integers a, b such that n = a**2 + b**2
    """
    for i in range(_isqrt(n) + 1):
        rest = n - i**2
        j = _isqrt(rest)
        if j**2 == rest:
            return i, j
    assert False, "{0} is not a sum of two squares".format(n)

##################################################