619

 avatar
unknown
python
2 hours ago
3.0 kB
82
No Index
# Prime modulus (10**9 + 7) used as the default for reported counts.
MOD = 1_000_000_007

def sieve(n):
    """Return the list of all primes p with 2 <= p <= n (empty when n < 2).

    Sieve of Eratosthenes over a bytearray of flags: every prime found strikes
    out its multiples from its square upward with a single slice assignment.
    """
    if n < 2:
        return []
    is_prime = bytearray([1]) * (n + 1)
    is_prime[0] = is_prime[1] = 0
    found = []
    for cand in range(2, n + 1):
        if not is_prime[cand]:
            continue
        found.append(cand)
        first = cand * cand
        if first <= n:
            # Zero every multiple of cand in [cand**2, n] in one C-level write;
            # the replacement length must match the extended slice exactly.
            is_prime[first::cand] = bytes(len(range(first, n + 1, cand)))
    return found


def C_square_subsets(a, b, mod=MOD, verbose=False):
    """
    Compute C(a,b): number of non-empty subsets of {a,...,b} whose product is a perfect square,
    modulo mod.

    Every integer x in [a, b] is encoded as a bit vector over GF(2) whose set
    bits mark the primes dividing x to an odd power. A subset's product is a
    perfect square exactly when the XOR of its vectors is zero. The number of
    such subsets (the empty one included) is therefore 2**(n - r), where
    n = b - a + 1 and r is the rank of the vectors.

    Args:
        a: lower bound of the range (inclusive); must be >= 1 when a <= b.
        b: upper bound of the range (inclusive).
        mod: modulus applied to the result.
        verbose: if true, print the sieve size and the final rank.

    Returns:
        C(a, b) modulo ``mod``; 0 when the range is empty (a > b).

    Raises:
        ValueError: if the range is non-empty and a < 1. Zero and negative
            numbers cannot be encoded as exponent-parity vectors.
    """
    if a > b:
        # Empty range: there are no non-empty subsets. This is handled
        # explicitly because n - r would be negative, and pow(2, negative, mod)
        # returns a modular inverse (Python >= 3.8) instead of a count.
        return 0
    if a < 1:
        raise ValueError(f"a must be >= 1, got {a}")

    n = b - a + 1

    # Each x <= b has at most one prime factor > sqrt(b), since two such
    # factors would multiply to more than b. So trial division by primes up to
    # sqrt(b) leaves a cofactor that is either 1 or a single prime.
    # The +1 guards against float rounding in the square root.
    limit = int(b**0.5) + 1
    small_primes = sieve(limit)

    if verbose:
        print(f"sqrt(b) ≈ {limit}, small primes: {len(small_primes)}")

    # Small primes get fixed low bit indices. Larger prime cofactors are
    # assigned fresh indices the first time they are seen below.
    prime_to_bit = {p: i for i, p in enumerate(small_primes)}
    bit_count = len(small_primes)

    # XOR basis over GF(2): pivot bit (highest set bit) -> basis vector.
    basis = {}

    for x in range(a, b + 1):
        y = x
        v = 0

        # Trial-divide by small primes, recording only the parity of each exponent.
        for p in small_primes:
            if p * p > y:
                break
            if y % p == 0:
                odd = 0
                while y % p == 0:
                    y //= p
                    odd ^= 1
                if odd:
                    v ^= 1 << prime_to_bit[p]

        # A cofactor > 1 is a prime with exponent 1. It may still be one of
        # the small primes (e.g. x = 6 leaves 3), hence the lookup before
        # allocating a new bit.
        if y > 1:
            idx = prime_to_bit.get(y)
            if idx is None:
                idx = bit_count
                prime_to_bit[y] = idx
                bit_count += 1
            v ^= 1 << idx

        # Reduce v against the basis; a nonzero remainder extends the basis.
        while v:
            pivot = v.bit_length() - 1
            bvec = basis.get(pivot)
            if bvec is None:
                basis[pivot] = v
                break
            v ^= bvec

    r = len(basis)
    if verbose:
        print(f"n = {n}, rank = {r}, bit dimensions used = {bit_count}")

    # The solution space of "XOR of chosen vectors == 0" has 2**(n - r)
    # elements (n >= r here, so the exponent is non-negative).
    # Subtract 1 to exclude the empty subset.
    return (pow(2, n - r, mod) - 1) % mod


if __name__ == "__main__":
    # (label, a, b). The first three rows are the worked examples, expected
    # to print 3, 15 and 975523611; the last row is the full-size instance.
    demo_cases = [
        ("C(5,10) =", 5, 10),
        ("C(40,55) =", 40, 55),
        ("C(1000,1234) mod =", 1000, 1234),
        ("C(1000000,1234567) mod 1e9+7 =", 1_000_000, 1_234_567),
    ]
    for label, lo, hi in demo_cases:
        # The verbose diagnostics print before the result line, as before.
        print(label, C_square_subsets(lo, hi, verbose=True))
Editor is loading...