#!/usr/bin/env python3
"""
Computation 54 -- Identifying the S_D irrep that saturates sigma_2 = D - 2
==========================================================================
Comp 52 numerically established that the off-diagonal block
T_pm = P_+ T_tail P_-  has sigma_max = D - 2 on the orthogonal complement of
the S_D-symmetric subspace.  Comp 53 ruled out the standard rep [D-1, 1]
(sigma_max|_[D-1, 1] = 1.63 at D = 4, ..., 5.40 at D = 8 -- strictly < D - 2).

This computation projects T_pm onto each 2-row irrep [D-j, j] via character
projectors and reports sigma_max within each block.  The block that gives
sigma_max = D - 2 is the saturating irrep.

Method
------
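  P_lambda  =  (dim lambda / |S_D|) sum_{g in S_D} chi_lambda(g) rho_F(g)
where rho_F(g) is the FERMIONIC (sign-twisted) permutation of the D bits of
basis |x> in C^{2^D}: rho_F(g)|S> = sgn(g | S) |g(S)>.  The plain bit
permutation does not commute with D_sub (Jordan-Wigner strings), which is why
rho_F is the representation used here.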

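For the PLAIN bit-permutation representation the constituents are the 2-row
irreps [D-j, j] (Young's rule: the permutation rep on k-subsets of [D]
decomposes as the direct sum of [D-j, j] for j = 0, ..., min(k, D-k)).
The fermionic rep rho_F differs.  On Hamming weight w it is the exterior power
Lambda^w C^D = Lambda^w [D-1, 1] + Lambda^{w-1} [D-1, 1], and
Lambda^k [D-1, 1] is the hook [D-k, 1^k].  Hence, of the 2-row irreps, only
[D] and [D-1, 1] occur in rho_F (each with multiplicity 2).  The blocks
[D-j, j] with j >= 2 are empty.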

Frobenius character formula
---------------------------
  chi_{[D-j, j]}(mu)  =  coeff of x^{D-j+1} y^j in
       (x - y) * prod_i (x^{mu_i} + y^{mu_i})
where mu is the cycle type of g, j >= 1.  For j = 0 (trivial rep [D]) the
character is the constant 1.

Native runtime budget
---------------------
  D = 4: < 0.1 s
  D = 5: ~ 0.3 s
  D = 6: ~ 2 s
  D = 7: ~ 30 s
  D = 8 not run in-browser (would take a few minutes).
"""
import math
from itertools import permutations
from functools import lru_cache
import numpy as np
import numpy.linalg as la


def popcount(x):
    """Return the number of set bits in the non-negative integer x.

    Kernighan's trick: ``x & (x - 1)`` clears the lowest set bit, so the
    loop body runs once per set bit rather than once per bit position.
    """
    count = 0
    while x:
        x &= x - 1
        count += 1
    return count


def build_T_tail_diagonal(D, kD):
    """Return the diagonal of T_tail = sum_{|S| > kD} chi_S as a complex vector.

    T_tail is diagonal in the computational basis.  Its entry on |x> depends
    only on the Hamming weight w of x: it is the sum over tail degrees
    j = kD + 1, ..., D of the Krawtchouk value

        K_j(w) = sum_i (-1)^i C(w, i) C(D - w, j - i).

    Returns a length-2**D array with dtype complex.
    """
    def krawtchouk(j, w):
        # Terms with j - i > D - w have C(D - w, j - i) = 0 and are skipped.
        return sum(
            ((-1) ** i) * math.comb(w, i) * math.comb(D - w, j - i)
            for i in range(min(w, j) + 1)
            if j - i <= D - w
        )

    tail_value = [
        sum(krawtchouk(j, w) for j in range(kD + 1, D + 1))
        for w in range(D + 1)
    ]
    return np.array([tail_value[popcount(x)] for x in range(1 << D)], dtype=complex)


def build_D_sub(D):
    """Build D_sub = sum_a chi_a^Cliff as a dense (2**D, 2**D) complex matrix.

    chi_a^Cliff is the Jordan-Wigner Clifford generator on site a.  It maps
    |x> to |x XOR 2^a> with sign (-1)^(number of set bits of x below a).
    """
    dim = 1 << D
    D_sub = np.zeros((dim, dim), dtype=complex)
    for site in range(D):
        flip = 1 << site
        below_mask = flip - 1  # bits strictly below `site` form the JW string
        for col in range(dim):
            odd_string = bin(col & below_mask).count("1") & 1
            D_sub[col ^ flip, col] += -1 if odd_string else 1
    return D_sub


def spectral_projectors(D_sub_mat, D):
    """Return (P_+, P_-), the projectors onto the +-sqrt(D) eigenspaces of D_sub.

    Relies on D_sub^2 = D * I, so D_sub / sqrt(D) is an involution and

        P_+- = (I +- D_sub / sqrt(D)) / 2.
    """
    scaled = D_sub_mat / math.sqrt(D)
    identity = np.eye(D_sub_mat.shape[0], dtype=complex)
    P_plus = (identity + scaled) / 2.0
    P_minus = (identity - scaled) / 2.0
    return P_plus, P_minus


def cycle_type(g):
    """Return the cycle type of permutation g (in one-line notation).

    The result is a tuple of cycle lengths sorted from longest to shortest,
    e.g. (1, 0, 2) -> (2, 1).
    """
    n = len(g)
    visited = [False] * n
    lengths = []
    for start in range(n):
        size = 0
        node = start
        while not visited[node]:
            visited[node] = True
            node = g[node]
            size += 1
        if size:  # size == 0 means `start` was already in a recorded cycle
            lengths.append(size)
    lengths.sort(reverse=True)
    return tuple(lengths)


@lru_cache(maxsize=None)
def chi_2row(D, j, ctype):
    """Character of the S_D irrep [D-j, j] on cycle type ctype (Frobenius formula).

        chi_{[D-j, j]}(mu) = [x^{D-j+1} y^j] (x - y) * p_mu(x, y),
        p_mu(x, y) = prod_i (x^{mu_i} + y^{mu_i}).

    Multiplying by (x - y) only shifts degrees, so the wanted coefficient is
    [x^{D-j} y^j] p_mu - [x^{D-j+1} y^{j-1}] p_mu.  j = 0 (trivial rep)
    returns 1.  Results are memoized; ctype must be hashable (a tuple).
    """
    if j == 0:
        return 1
    # p_mu stored as {(deg_x, deg_y): coefficient}.
    p_mu = {(0, 0): 1}
    for m in ctype:
        grown = {}
        for (dx, dy), coeff in p_mu.items():
            for key in ((dx + m, dy), (dx, dy + m)):
                grown[key] = grown.get(key, 0) + coeff
        p_mu = grown
    return p_mu.get((D - j, j), 0) - p_mu.get((D - j + 1, j - 1), 0)


def perm_action_on_H(g, D):
    """Return the FERMIONIC-signed permutation rho_F(g) acting on H = C^{2^D}.

    A naive bit permutation does NOT commute with D_sub = sum_a chi_a^Cliff.
    Each chi_a^Cliff carries a Jordan-Wigner sigma_z string that depends on
    the site ordering.  The representation that does commute with D_sub (and
    with T_tail, P_+, P_-) is the fermionic one:

        rho_F(g) |S>  =  sgn(g | S) * |g(S)>

    Here |S> = chi_S^Cliff |0> with S sorted ascending.  sgn(g | S) is the
    sign of the permutation that sorts (g(a_1), ..., g(a_k)) into ascending
    order, where k = |S|.

    Returns (perm, sign), two int64 arrays of length 2**D, such that
    rho_F(g) |x> = sign[x] * |perm[x]>.
    """
    n_states = 1 << D
    perm = np.zeros(n_states, dtype=np.int64)
    sign = np.ones(n_states, dtype=np.int64)
    for x in range(n_states):
        occupied = [site for site in range(D) if (x >> site) & 1]
        images = [g[site] for site in occupied]
        # Parity of the number of inversions = sign of the sorting permutation.
        inversions = sum(
            1
            for pos, hi in enumerate(images)
            for lo in images[pos + 1:]
            if hi > lo
        )
        sign[x] = -1 if inversions % 2 else 1
        target = 0
        for img in images:
            target |= 1 << img
        perm[x] = target
    return perm, sign


def dim_2row(D, j):
    """Dimension of the S_D irrep [D-j, j].

    Equals C(D, j) - C(D, j-1) for j >= 1, and 1 for the trivial rep (j = 0).
    """
    return 1 if j == 0 else math.comb(D, j) - math.comb(D, j - 1)


def build_projector_lambda(D, j):
    """Character projector onto the [D-j, j] isotypic block of rho_F.

        P_lambda = (dim lambda / D!) * sum_{g in S_D} chi_{[D-j, j]}(g) rho_F(g)

    rho_F(g) is the signed permutation from perm_action_on_H, so each group
    element contributes chi(g) * sign[x] at entry (perm[x], x).  If [D-j, j]
    does not occur in rho_F, the result is (numerically) the zero matrix.

    Returns a dense complex (2**D, 2**D) matrix.
    """
    scale = dim_2row(D, j) / math.factorial(D)
    size = 1 << D
    projector = np.zeros((size, size), dtype=complex)
    source = np.arange(size)
    for g in permutations(range(D)):
        weight = chi_2row(D, j, cycle_type(g))
        if weight != 0:  # zero-character elements contribute nothing
            target, signs = perm_action_on_H(list(g), D)
            np.add.at(projector, (target, source), weight * signs)
    projector *= scale
    return projector


def chi_hook(D, k, ctype):
    """Character of the S_D hook irrep [D-k, 1^k] on cycle type ctype.

    The hook [D-k, 1^k] is the exterior power Lambda^k [D-1, 1] of the
    standard rep.  Let e_m be the character of Lambda^m C^D, the m-th
    exterior power of the natural permutation rep.  Its generating function is

        sum_m e_m t^m  =  det(I + t rho_nat(g))  =  prod_i (1 - (-t)^{mu_i}).

    Since Lambda^m C^D = Lambda^m [D-1, 1] + Lambda^{m-1} [D-1, 1], the hook
    character is the alternating partial sum e_k - e_{k-1} + ... +- e_0.

    k = 0 is the trivial rep, k = 1 the standard rep [D-1, 1], and k = D - 1
    the sign rep.  For ctype summing to D, any k outside 0..D-1 gives 0.
    D is unused; it is kept so the signature parallels chi_2row.
    """
    # e[m] = coefficient of t^m in prod_i (1 - (-1)^{mu_i} t^{mu_i}).
    e = [1]
    for mu_i in ctype:
        expanded = e + [0] * mu_i
        for m, c in enumerate(e):
            expanded[m + mu_i] -= (-1) ** mu_i * c
        e = expanded
    # Peel off the Lambda^{m-1}[D-1, 1] summand: chi_m = e_m - chi_{m-1}.
    chi = 0
    for m in range(k + 1):
        chi = (e[m] if m < len(e) else 0) - chi
    return chi


def _isotypic_projector(D, dim_lambda, chi_by_ctype, actions):
    """Return P_lambda = (dim lambda / D!) sum_g chi_lambda(g) rho_F(g).

    chi_by_ctype maps a cycle type to the character value.  actions holds one
    (ctype, perm, sign) triple per g in S_D, where
    rho_F(g)|x> = sign[x] |perm[x]>.
    """
    N = 1 << D
    cols = np.arange(N)
    P = np.zeros((N, N), dtype=complex)
    for ctype, perm, sign in actions:
        chi = chi_by_ctype[ctype]
        if chi == 0:
            continue
        np.add.at(P, (perm, cols), chi * sign)
    P *= dim_lambda / math.factorial(D)
    return P


def main():
    """Run Computation 54 and print its tables.

    Part 1 (D = 4..6) checks that the fermionic rep rho_F commutes with
    D_sub, while the plain bit permutation does not.

    Part 2 (D = 4..7, k_D = 2) builds T_pm = P_+ T_tail P_-, projects it
    onto S_D isotypic blocks of rho_F, and reports sigma_max in each block.

    Which irreps are scanned, and why: restricted to Hamming weight w, rho_F
    is the exterior power Lambda^w C^D = Lambda^w [D-1, 1] + Lambda^(w-1)
    [D-1, 1], and Lambda^k [D-1, 1] is the hook [D-k, 1^k].

    - Among the 2-row irreps [D-j, j], only j = 0 and j = 1 occur; the blocks
      with j >= 2 are empty.
    - The hooks [D-k, 1^k], k >= 2, are therefore scanned too.
    - Multiplicities are measured as tr(P_lambda) / dim(lambda).  Young's
      rule (D - 2j + 1) describes the plain permutation rep, not rho_F.
    - The last line of each table checks that sum_lambda dim * mult = 2^D,
      i.e. that the scanned blocks exhaust H.
    """
    print("=" * 90)
    print("  Computation 54  --  Identifying the S_D irrep that saturates sigma_2 = D - 2")
    print("=" * 90)
    print()
    print("  Method: project T_pm = P_+ T_tail P_- onto each S_D isotypic block of the")
    print("  fermionic rep rho_F via character projectors")
    print("  P_lambda = (dim/|G|) sum_g chi_lambda(g) rho_F(g), report sigma_max within each block.")
    print("  On Hamming weight w, rho_F = Lambda^w C^D = Lambda^w [D-1,1] + Lambda^(w-1) [D-1,1],")
    print("  whose constituents are the hooks [D-k, 1^k]; 2-row irreps [D-j, j] with j >= 2")
    print("  do not occur (empty blocks).  Both families are scanned below.")
    print()
    print("  Comp 52: sigma_max(T_pm | non-sym) = D - 2 exactly for D = 4..10.")
    print("  Comp 53: sigma_max(T_pm | [D-1, 1]) is strictly < D - 2.")
    print("  -> inside rho_F the remaining candidates are the hooks [D-k, 1^k], k >= 2.")
    print()

    print("  Sanity check: rho_F(g) commutes with D_sub (the bit-permutation rep does NOT,")
    print("  because of the JW sigma_z strings).")
    print()
    print(f"  {'D':>3}  {'max_g ||[rho_F(g), D_sub]||':>30}  {'max_g ||[rho_bit(g), D_sub]||':>32}")
    for D in range(4, 7):
        D_mat = build_D_sub(D)
        N = 1 << D
        cols = np.arange(N)
        worst_F = 0.0
        worst_bit = 0.0
        for g_tuple in permutations(range(D)):
            perm, sign = perm_action_on_H(list(g_tuple), D)
            # fermionic rep: signed permutation matrix
            R_F = np.zeros((N, N), dtype=complex)
            R_F[perm, cols] = sign
            worst_F = max(worst_F, la.norm(R_F @ D_mat - D_mat @ R_F, ord=2))
            # plain bit-permutation rep: same permutation, no signs
            R_bit = np.zeros((N, N), dtype=complex)
            R_bit[perm, cols] = 1.0
            worst_bit = max(worst_bit, la.norm(R_bit @ D_mat - D_mat @ R_bit, ord=2))
        print(f"  {D:>3}  {worst_F:>30.2e}  {worst_bit:>32.4f}")
    print()

    kD = 2
    for D in range(4, 8):
        print("=" * 90)
        print(f"  D = {D},  k_D = {kD},  target sigma_2 = D - 2 = {D - 2}")
        print("=" * 90)
        D_sub_mat = build_D_sub(D)
        T_tail = np.diag(build_T_tail_diagonal(D, kD))
        Pp, Pm = spectral_projectors(D_sub_mat, D)
        T_pm = Pp @ T_tail @ Pm

        sigma_total = la.svd(T_pm, compute_uv=False)
        # report top 4 singular values
        head = [f"{s:.4f}" for s in sigma_total[:4]]
        print(f"  Top 4 sigma(T_pm):  {head}")
        print()

        # rho_F(g) for every g in S_D, computed once per D and shared by all
        # projectors (perm_action_on_H dominates the runtime).
        actions = []
        for g_tuple in permutations(range(D)):
            perm, sign = perm_action_on_H(list(g_tuple), D)
            actions.append((cycle_type(g_tuple), perm, sign))
        ctypes = {ctype for ctype, _, _ in actions}

        # (label, dim, {cycle type: character}) for every scanned irrep.
        irreps = [
            (f"[{D - j}, {j}]", dim_2row(D, j),
             {ct: chi_2row(D, j, ct) for ct in ctypes})
            for j in range(D // 2 + 1)
        ]
        irreps += [
            (f"[{D - k}, 1^{k}]", math.comb(D - 1, k),
             {ct: chi_hook(D, k, ct) for ct in ctypes})
            for k in range(2, D)
        ]

        print(f"  {'irrep':>10}  {'dim':>6}  {'multiplicity':>12}  {'sigma_max in block':>20}  {'saturates D-2?':>16}")
        covered = 0
        for label, dim_lam, chars in irreps:
            P_lambda = _isotypic_projector(D, dim_lam, chars, actions)
            # tr(P_lambda) = dim(lambda) * multiplicity of lambda in rho_F.
            mult = int(round(np.trace(P_lambda).real / dim_lam))
            covered += dim_lam * mult
            if mult == 0:
                sigma_max = 0.0
                sat = "empty block"
            else:
                T_pm_lambda = P_lambda @ T_pm @ P_lambda
                sigma_max = la.svd(T_pm_lambda, compute_uv=False)[0]
                sat = "YES" if abs(sigma_max - (D - 2)) < 1e-6 else "no"
            print(f"  {label:>10}  {dim_lam:>6}  {mult:>12}  {sigma_max:>20.6f}  {sat:>16}")
        print(f"  sum_lambda dim * mult = {covered}   (2^D = {1 << D})")
        print()

    print("=" * 90)
    print("  Findings")
    print("=" * 90)
    print()
    print("  The irrep that saturates sigma_max = D - 2 (the YES row above) identifies the")
    print("  block where the sub-dominant SVD direction lives.  Closed-form analysis")
    print("  of T_pm restricted to this block, plus the closed form sigma_1 = 2^(D-1) - D")
    print("  on the trivial block (Comp 52), then gives the complete operator-norm bound")
    print("  ||T_pm||_op = max(sigma_1, sigma_2) = sigma_1 for D >= 4 (exponential gap).")


if __name__ == "__main__":
    # Run the computation only when executed as a script, not on import.
    main()
