#!/usr/bin/env python3
"""
Computation 52 -- SVD profile of the Walsh-tail off-diagonal block
====================================================================
Phase 1.1 of the L_round closure attack (research/dirac_proof.md).

The block-decomposition identity of Computation 51 reduced the L_round
closure to bounding the off-diagonal block of T_tail in the D_sub
eigenbasis:

    max(|| (T_tail)_{+-} ||_op, || (T_tail)_{-+} ||_op)  <=  C * 2^(D-1)

with P_pm = (I +- D_sub/sqrt(D))/2 the spectral projectors of D_sub,
(T_tail)_{+-} = P_+ T_tail P_-, and the numerical value 2^(D-1) - D
observed exactly across D in {4,5,6,7,8} at cutoff k_D = 2 in
Computation 51.

This script extends that numerical evidence and profiles the
dominant singular vector of (T_tail)_{+-} for analytical
pattern-matching:
  (1) SVD of (T_tail)_{+-} for D = 4..10 at fixed k_D = 2.
  (2) Tabulate top 5 singular values + compare with conjectured
      formula sigma_1 = 2^(D-1) - D.
  (3) Decompose the dominant right singular vector by Hamming weight
      of the computational-basis indices, looking for a closed-form
      pattern that would let us identify v_1 analytically.

If sigma_1 = 2^(D-1) - D continues through D = 10, the conjecture
hardens. If the dominant singular vector has a clean structure
(e.g., constant on weight strata), step 1.3 of the attack plan
falls quickly: we identify v_1 in closed form, compute
sigma_1 = ||(T_tail)_{+-} v_1|| analytically, and finish the bound.
"""
import math
import numpy as np
import numpy.linalg as la
from itertools import combinations


# Single-qubit building blocks (complex128): Pauli X, Pauli Z, identity.
sx = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex)
sz = np.diag(np.array([1.0, -1.0], dtype=complex))
I2 = np.identity(2, dtype=complex)


def kron_chain(ops):
    """Left-to-right Kronecker product ops[0] (x) ops[1] (x) ... (x) ops[-1].

    A one-element sequence returns that element unchanged.
    """
    head, tail = ops[0], ops[1:]
    acc = head
    for factor in tail:
        acc = np.kron(acc, factor)
    return acc


def chi_S(D, S):
    """Walsh operator chi_S = prod_{a in S} Z_a on D qubits.

    Qubit 0 is the leftmost Kronecker factor. Every factor is diagonal, so
    chi_S is diagonal in the computational basis.
    """
    factors = []
    for site in range(D):
        factors.append(sz if site in S else I2)
    return kron_chain(factors)


def chi_a_Cliff(D, a):
    """Jordan-Wigner-style string Z^{(x)a} (x) X (x) I^{(x)(D-1-a)}.

    Qubit 0 is the leftmost factor: a leading Z string, X on qubit a, and
    identity on the remaining qubits.
    """
    z_string = [sz] * a
    identity_tail = [I2] * (D - 1 - a)
    return kron_chain(z_string + [sx] + identity_tail)


def D_sub(D):
    """Return D_sub = sum_{a=0}^{D-1} chi_a_Cliff(D, a) as a dense matrix.

    The terms are accumulated in order a = 0, 1, ..., D-1. The caller in
    main() asserts that D_sub^2 = D * I.
    """
    first = chi_a_Cliff(D, 0)
    return sum((chi_a_Cliff(D, a) for a in range(1, D)), first)


def build_T_tail(D, kD):
    """Return T_tail = sum_{|S| > kD} chi_S as a dense diagonal matrix.

    Each chi_S is diagonal in the computational basis with entry
    (-1)^{|S cap x|}. The diagonal entry <x|T_tail|x> therefore depends
    only on the Hamming weight w = |x|, and equals sum_{j=kD+1}^{D} K_j(w),
    where K_j(w) = sum_i (-1)^i C(w, i) C(D-w, j-i) is the Krawtchouk sum.

    Evaluating these integers directly is much faster than summing the
    chi_S matrices for large D, which matters in Pyodide-on-browser runs
    of this script. Returns a (2^D, 2^D) complex array.
    """
    def kraw(j, w):
        # i counts how many of the j sites in S fall on 1-bits of x.
        return sum(
            (-1) ** i * math.comb(w, i) * math.comb(D - w, j - i)
            for i in range(min(w, j) + 1)
            if j - i <= D - w
        )

    eig_by_weight = [
        sum(kraw(j, w) for j in range(kD + 1, D + 1)) for w in range(D + 1)
    ]
    diagonal = [eig_by_weight[bin(x).count("1")] for x in range(2 ** D)]
    return np.diag(np.array(diagonal, dtype=complex))


def spectral_projectors(D_mat, D):
    """Return (P_+, P_-) with P_pm = (I +- D_mat / sqrt(D)) / 2.

    These are the spectral projectors onto the +-sqrt(D) eigenspaces only
    when D_mat^2 = D * I. That condition is not checked here.
    """
    n = D_mat.shape[0]
    identity = np.eye(n, dtype=complex)
    scaled = D_mat / math.sqrt(D)
    return 0.5 * (identity + scaled), 0.5 * (identity - scaled)


def popcount(x):
    """Number of '1' digits in the binary representation of the integer x."""
    return sum(ch == "1" for ch in bin(x))


def profile_singular_vector(v, D):
    """Hamming-weight mass profile of a vector on D qubits.

    Returns a float array `masses` of length D+1 with
    masses[w] = sum_{|x| = w} |<x|v>|^2. Indices x are visited in
    increasing order.
    """
    masses = np.zeros(D + 1)
    for idx in range(2 ** D):
        masses[bin(idx).count("1")] += abs(v[idx]) ** 2
    return masses


def symmetric_projector(D):
    """Projector onto the S_D-symmetric subspace of (C^2)^otimes D.

    The subspace is spanned by the orthonormal Dicke states
    |w>_sym = (1/sqrt(C(D,w))) sum_{|x|=w} |x>, w = 0..D. Hence
    <x|P|y> = 1/C(D,w) when |x| = |y| = w, and 0 otherwise.

    The matrix is filled entry-wise in one vectorized pass. The earlier
    version accumulated D+1 dense outer products, O(D * 4^D) work, and
    stored (1/sqrt(C))^2 rather than 1/C.

    Returns a (2^D, 2^D) complex array.
    """
    dim = 1 << D
    idx = np.arange(dim)
    # Hamming weight of every basis index, computed bit by bit.
    weight = np.zeros(dim, dtype=int)
    for a in range(D):
        weight += (idx >> a) & 1
    same_weight = weight[:, None] == weight[None, :]
    inv_binom = np.array([1.0 / math.comb(D, w) for w in range(D + 1)])
    # Row x of a same-weight pair gets 1/C(D, |x|); everything else is 0.
    return np.where(same_weight, inv_binom[weight][:, None], 0.0).astype(complex)


def standard_rep_weight1_basis(D):
    """Orthonormal basis for the weight-1 standard rep of S_D.

    Here |e_a> is the basis state with index 1 << a. The subspace is the
    (D-1)-dim part of H_1 = span{|e_a>} on which sum_a c_a = 0. The
    columns are the Gram-Schmidt orthonormalisation of
    {|e_a> - |e_{a-1}> : a = 1..D-1}, written in exact closed form:

        u_a = (a |e_a> - sum_{b<a} |e_b>) / sqrt(a (a+1)),   a = 1..D-1.

    Returns a (2^D, D-1) complex array. For D < 2 the subspace is empty and
    the result has zero columns; the previous np.column_stack([]) raised
    ValueError in that case.
    """
    dim = 1 << D
    B = np.zeros((dim, max(D - 1, 0)), dtype=complex)
    for a in range(1, D):
        norm = math.sqrt(a * (a + 1))
        B[1 << a, a - 1] = a / norm
        for b in range(a):
            B[1 << b, a - 1] = -1.0 / norm
    return B


def _offdiag_block(D, kD):
    """Return (T_tail)_{+-} = P_+ T_tail P_- for D modes at cutoff kD."""
    dim = 1 << D
    D_mat = D_sub(D)
    # Sanity: D_sub^2 = D * I, which makes P_pm = (I +- D_sub/sqrt(D))/2
    # projectors.
    assert np.allclose(D_mat @ D_mat, D * np.eye(dim, dtype=complex))
    Pp, Pm = spectral_projectors(D_mat, D)
    t_diag = np.diag(build_T_tail(D, kD))
    # T_tail is diagonal, so P_+ T_tail is P_+ with column j scaled by
    # t_diag[j]. This saves one dense dim x dim product per D.
    return (Pp * t_diag) @ Pm


def _weight_strata(D):
    """Computational-basis indices grouped by Hamming weight w = 0..D."""
    strata = [[] for _ in range(D + 1)]
    for x in range(1 << D):
        strata[popcount(x)].append(x)
    return [np.array(idx, dtype=int) for idx in strata]


def _matches(value, target, rtol=1e-8):
    """True if a computed singular value equals a closed-form target."""
    return abs(value - target) <= rtol * max(1.0, abs(target))


def _banner(title):
    """Print a section header framed by rules of '='."""
    print("=" * 90)
    print(f"  {title}")
    print("=" * 90)
    print()


def _svd_table(D_values, kD):
    """Print sigma_1..sigma_3 of T_pm against 2^(D-1) - D; return per-D data."""
    print(f"  Fixed cutoff: k_D = {kD}")
    print("  Conjecture (from Comp 51): sigma_1 = 2^(D-1) - D")
    print()
    print(f"  {'D':>3}  {'dim':>5}  {'sigma_1':>11}  {'2^(D-1)-D':>11}  "
          f"{'sigma_2':>11}  {'sigma_3':>11}  {'ratio':>10}")
    data = {}
    for D in D_values:
        dim = 1 << D
        T_pm = _offdiag_block(D, kD)
        # T_pm = U diag(s) Vh, so v_1 = Vh[0].conj() is the right singular
        # vector with T_pm v_1 = s[0] u_1.
        _, s, Vh = la.svd(T_pm)
        data[D] = {"T_pm": T_pm, "s": s, "v1": Vh[0].conj(),
                   "strata": _weight_strata(D)}
        target = 2 ** (D - 1) - D
        print(f"  {D:>3}  {dim:>5}  {s[0]:>11.4f}  {target:>11}  "
              f"{s[1]:>11.4f}  {s[2]:>11.4f}  {s[0] / max(target, 1):>10.6f}")
    return data


def _weight_profiles(D_values, data):
    """Print the Hamming-weight mass profile of v_1 for each D."""
    print()
    _banner("Dominant singular vector profile (right singular vector v_1)")
    print("  For each D, decompose v_1 by Hamming weight of the")
    print("  computational-basis index. Entry w shows sum_{|x|=w} |<x|v_1>|^2.")
    print("  Looking for constant-coefficient patterns on weight strata.")
    print()
    for D in D_values:
        s, v1, strata = data[D]["s"], data[D]["v1"], data[D]["strata"]
        print(f"  D = {D},  sigma_1 = {s[0]:.4f}")
        if s[0] - s[1] <= 1e-8 * max(1.0, s[0]):
            # A repeated top singular value makes v_1 an arbitrary unit
            # vector of a subspace, so its profile is basis-dependent.
            print("    (warning: sigma_1 is degenerate; v_1 is not unique)")
        for w, mass in enumerate(profile_singular_vector(v1, D)):
            count = math.comb(D, w)
            state_mass = np.abs(v1[strata[w]]) ** 2
            # Equal mass requires every |x| = w state to carry the same
            # |<x|v_1>|^2. The old test compared mass/count * count with
            # mass, which always holds.
            equal = (count > 1 and mass > 1e-10
                     and state_mass.max() - state_mass.min() < 1e-10)
            marker = " <- equal-mass stratum" if equal else ""
            print(f"    weight {w} ({count:>3} states): "
                  f"total mass = {mass:.6f}  "
                  f"per-state = {mass / count:.6e}{marker}")
        print()


def _amplitude_table(D_values, data):
    """Print v_1's phase-fixed amplitude per weight stratum, flagging constancy."""
    _banner("Algebraic structure of v_1")
    print("  For each D, list the per-state amplitude on each weight stratum")
    print("  (up to global phase). Constant amplitude within a stratum is the")
    print("  signature of a permutation-symmetric v_1.")
    print()
    for D in D_values:
        v1, strata = data[D]["v1"], data[D]["strata"]
        # Fix the global phase so that v1[0] is real-positive. If v1[0]
        # vanishes, use the largest-magnitude entry instead.
        phase_ref = v1[0]
        if abs(phase_ref) < 1e-12:
            phase_ref = v1[np.argmax(np.abs(v1))]
        v1n = v1 * np.conj(phase_ref) / abs(phase_ref)
        print(f"  D = {D}:")
        for w, idx in enumerate(strata):
            amps = v1n[idx]
            nonzero = amps[np.abs(amps) > 1e-10]
            if nonzero.size == 0:
                continue
            a0 = nonzero[0]
            # CONSTANT requires every state of the stratum, zeros included,
            # to carry a0. Partial support counts as "varies".
            const = bool(np.all(np.abs(amps - a0) < 1e-8))
            marker = " (CONSTANT)" if const else " (varies)"
            print(f"    weight {w}: {nonzero.size} of {idx.size} non-zero, "
                  f"amplitude = {a0.real:+.6f}{a0.imag:+.6f}i{marker}")
        print()


def _symmetric_split(D_values, data):
    """Compare T_pm on the S_D-symmetric subspace and its complement.

    Returns True if sigma_max on the symmetric subspace equals
    2^(D-1) - D for every D.
    """
    _banner("Symmetric vs non-symmetric decomposition of (T_tail)_{+-}")
    print("  Restrict T_pm to (a) S_D-symmetric subspace and")
    print("  (b) its orthogonal complement. Compare top singular values.")
    print("  'cross' = max Frobenius norm of S T_pm N and N T_pm S; the two")
    print("  restrictions form a block decomposition of T_pm only if it is ~0.")
    print()
    print(f"  {'D':>3}  {'sigma_max on sym':>17}  {'2^(D-1)-D':>11}  "
          f"{'sigma_max on non-sym':>21}  {'D-2':>5}  {'cross':>10}")
    all_ok = True
    for D in D_values:
        T_pm = data[D]["T_pm"]
        S = symmetric_projector(D)
        # With N = I - S, all four blocks follow from three products.
        ST = S @ T_pm
        TS = T_pm @ S
        T_sym = ST @ S                    # S T S
        T_non = T_pm - ST - TS + T_sym    # N T N
        cross = max(la.norm(ST - T_sym), la.norm(TS - T_sym))  # S T N, N T S
        sigma_sym = la.norm(T_sym, ord=2)
        sigma_non = la.norm(T_non, ord=2)
        target1 = 2 ** (D - 1) - D
        target2 = D - 2
        all_ok = all_ok and _matches(sigma_sym, target1)
        print(f"  {D:>3}  {sigma_sym:>17.4f}  {target1:>11}  "
              f"{sigma_non:>21.4f}  {target2:>5}  {cross:>10.2e}")
    return all_ok


def _standard_rep_table(D_values, data):
    """Print ||T_pm B||_op for B spanning the weight-1 standard rep.

    Returns True if it equals D - 2 for every D.
    """
    print()
    _banner("Restriction of T_pm to the weight-1 standard-rep subspace")
    print("  The standard rep of S_D at weight 1 is the (D-1)-dim subspace of")
    print("  H_1 with sum_a c_a = 0. We expect sigma_max on this subspace to")
    print("  give the analytical closed form for sigma_2 of T_pm.")
    print()
    print(f"  {'D':>3}  {'sigma_max on std rep H_1':>27}  {'D-2':>5}  {'ratio':>10}")
    all_ok = True
    for D in D_values:
        # B has orthonormal columns spanning the standard rep. The output
        # of T_pm lies in range(P_+), not in span B, so span B is used as
        # the INPUT space: sigma_max = ||T_pm B||_op.
        B = standard_rep_weight1_basis(D)
        sigma_max = la.norm(data[D]["T_pm"] @ B, ord=2)
        target = D - 2
        all_ok = all_ok and _matches(sigma_max, target)
        print(f"  {D:>3}  {sigma_max:>27.4f}  {target:>5}  "
              f"{sigma_max / max(target, 1):>10.6f}")
    return all_ok


def _recap(D_values, data, sym_ok, std_ok):
    """Print the closed forms together with this run's numerical verdict."""
    print()
    _banner("Closed-form analysis recap")
    sigma1_ok = all(_matches(data[D]["s"][0], 2 ** (D - 1) - D) for D in D_values)
    sigma2_ok = all(_matches(data[D]["s"][1], D - 2) for D in D_values)
    checked = ", ".join(str(D) for D in D_values)

    def verdict(ok):
        return "matches" if ok else "MISMATCH"

    print("  Closed forms (EXACT if every numerical check below matches):")
    print("    sigma_1 = 2^(D-1) - D  (S_D-symmetric subspace)")
    print("    sigma_2 = D - 2        (S_D standard rep at weight 1)")
    print()
    print(f"  Numerical check at D in {{{checked}}}:")
    print(f"    sigma_1 vs 2^(D-1) - D          : {verdict(sigma1_ok)}")
    print(f"    sigma_max on sym vs 2^(D-1) - D : {verdict(sym_ok)}")
    print(f"    sigma_2 vs D - 2                : {verdict(sigma2_ok)}")
    print(f"    sigma_max on std rep vs D - 2   : {verdict(std_ok)}")
    print()
    if sigma1_ok and sigma2_ok and sym_ok and std_ok:
        print("  For D >= 4, sigma_1 > sigma_2 with an exponential gap;")
        print("  hence ||T_pm||_op = sigma_1 = 2^(D-1) - D exactly,")
        print("  closing Lemma 2-prime of research/dirac_proof.md.")
    else:
        print("  At least one closed form disagrees with the numerics above;")
        print("  Lemma 2-prime of research/dirac_proof.md is NOT closed by this run.")


def main(D_values=None, kD=2):
    """Run the Computation 52 diagnostics and print the report.

    D_values: mode counts D to scan (each D >= 2). Defaults to D = 4..10,
        the range the module docstring sets out to cover; Computation 51
        already covered D = 4..8. D = 9 and 10 involve dense 512- and
        1024-dim SVDs, so pass a shorter range for quick or in-browser runs.
    kD: Walsh-weight cutoff; T_tail sums chi_S over |S| > kD.

    Each D's block (T_tail)_{+-} is built and decomposed once and reused
    by every section. Only the +- block is computed: T_tail and P_pm are
    Hermitian, so (T_tail)_{-+} = ((T_tail)_{+-})^dagger has the same
    operator norm.

    Raises ValueError if D_values is empty.
    """
    D_values = list(range(4, 11) if D_values is None else D_values)
    if not D_values:
        raise ValueError("D_values must contain at least one D")

    _banner("Computation 52 -- SVD profile of (T_tail)_{+-}")
    data = _svd_table(D_values, kD)
    _weight_profiles(D_values, data)
    _amplitude_table(D_values, data)
    sym_ok = _symmetric_split(D_values, data)
    std_ok = _standard_rep_table(D_values, data)
    _recap(D_values, data, sym_ok, std_ok)


if __name__ == "__main__":
    main()
