#!/usr/bin/env python3
"""
Computation 59 -- Lemma 7: uniform-in-k closure of the refined bridge at finite N
==================================================================================
Lemma 5(a, b, c) (Comps 56-58) closed the L_comm ratio match analytically:
for every k >= 1 there exist (m(k), alpha(k)) such that the 2-monomial
bridge symbol f_k(w) = w^(m(k)) + alpha(k) * w^(m(k)+1) achieves

    sup_{partial B^2} || M_{f_k} ||_op  /  sup_{partial B^2} | f_k |  =  2 sqrt(k)

EXACTLY on the INFINITE Bergman space H^2_alpha(B^2).  The remaining
structural question is how the gap behaves at FINITE Bergman truncation
H^2_alpha(B^2)^{(N)} (monomials z_0^a z_1^b with a + b <= N) and whether
it is UNIFORM in k as the substrate size D grows.

Setup
-----
Weighted Bergman space H^2_alpha(B^2) with the orthonormal basis

    e_{a, b}^(alpha) (z)  =  z_0^a z_1^b  /  sqrt( a! b! Gamma(alpha + 3) / Gamma(a + b + alpha + 3) )

SU(2) generators (acting on holomorphic polynomials):

    J_+ = z_0 d/dz_1,    J_- = z_1 d/dz_0,    J_z = (1/2)(z_0 d/dz_0 - z_1 d/dz_1)

Round-S^3 Dirac on H^2_alpha(B^2) (X) C^2:  D_alpha = sum_a J_a (X) sigma_a,
with J_1 = (J_+ + J_-)/2,  J_2 = (J_+ - J_-)/(2i),  J_3 = J_z.

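Holomorphic Toeplitz operator T_g for holomorphic symbol g:  T_g h = g h
(no Bergman projection needed since g h stays holomorphic).  On the finite
truncation the product is additionally compressed onto the monomials of
total degree <= N.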

Refined 2-monomial bridge for weight k:
    f_k(w)  =  w^{m(k)} + alpha(k) * w^{m(k)+1}      with w := z_0 z_1
    (m(k), alpha(k)) from Lemma 5(b, c)  (Comp 57, 58).

L_comm criterion (relaxed) at level (D, k, N):

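    gamma(D, k, N)  :=  | || [D_alpha, T_{f_k} (X) I_2] ||_op  /  || T_{f_k} ||_op  -  2 sqrt(k) |  /  (2 sqrt(k)).

The tables printed by main() report the signed quantity inside | . |, so a
positive entry is an overshoot of 2 sqrt(k) and a negative entry an undershoot.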

Lemma 7 conjecture (uniform-in-k closure rate)
----------------------------------------------
For the EXACT bridge with the Lemma 5(c) closing parameters, the
truncation gap satisfies

    gamma(D, k, N)  =  O(1 / N)  uniformly in k <= k_D = floor(sqrt(D))

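provided N >= 2 k_D + 2, which is intended to keep the top monomial
w^{m(k)+1} of the bridge (total degree 2 m(k) + 2) inside the degree-<= N
truncation.  (T_{f_k} strictly raises degree, so it is never bijective on the
truncated space.)  Hence with the matched scaling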
N(D) ~ 2 sqrt(D), gamma_D = O(1 / sqrt(D)), uniformly in k.

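This computation tests the conjecture numerically across D = 4, 6, 8 (the
substrate sizes used by the L_round closure verification) and
k = 1, ..., k_D, at several N >= 2 k_D + 2.  Note that floor(sqrt(D)) = 2 for
all three sizes, so each D produces the same table: k = 1, 2 at N = 6, 8, 10, 12.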
"""
import math
import numpy as np
import numpy.linalg as la
from numpy.linalg import norm


# ---------- Lemma 5(c) closing parameters from Comp 58 ----------

def ratio_cf(m):
    """Closed-form boundary ratio for the bare monomial symbol w^m.

    Returns 2.0 for m == 1 and, otherwise,

        m^(1 - m) * (m + 1)^((m + 1) / 2) * (m - 1)^((m - 1) / 2).

    This is the q -> 1/m limit of R_m_parametric(m, q), where the second
    coefficient alpha_of_q(m, q) vanishes; m_of_k compares it with 2 sqrt(k).
    """
    if m != 1:
        upper = (m + 1) ** ((m + 1) / 2.0)
        lower = (m - 1) ** ((m - 1) / 2.0)
        return m ** (1 - m) * upper * lower
    return 2.0


def m_of_k(k):
    """Monomial degree m(k) of the refined bridge for weight k (Lemma 5(b)).

    Returns the smallest m >= 1 with ratio_cf(m + 1) >= 2 sqrt(k).  For k > 1
    the search also leaves ratio_cf(m) < 2 sqrt(k) (every smaller m was
    rejected, and ratio_cf(1) = 2 < 2 sqrt(k)), so the target lies in
    [ratio_cf(m), ratio_cf(m + 1)] -- the bracket alpha_of_k bisects in.
    k == 1 is special-cased to m = 1: w^1 alone gives ratio_cf(1) = 2 = 2 sqrt(1).

    Raises
    ------
    ValueError
        If k < 1, where the lemma makes no claim (previously k = 0 silently
        returned 1 and produced a meaningless alpha_of_k).
    """
    if k < 1:
        raise ValueError(f"m_of_k requires k >= 1, got {k!r}")
    if k == 1:
        return 1
    target = 2.0 * math.sqrt(k)
    m = 1
    # ratio_cf(m) grows roughly like m, so this loop terminates for finite k.
    while ratio_cf(m + 1) < target:
        m += 1
    return m


def alpha_of_q(m, q):
    """Second-monomial coefficient alpha as a function of the parameter q.

        alpha(m, q) = 2 m (1 - m q) / ( (m + 1) sqrt(1 - q^2) ((m + 1) q - 1) )

    It vanishes at q = 1/m and has a pole at q = 1/(m + 1); alpha_of_k only
    evaluates it strictly inside (1/(m + 1), 1/m).
    """
    numerator = 2 * m * (1 - m * q)
    denominator = (m + 1) * math.sqrt(1 - q * q) * ((m + 1) * q - 1)
    return numerator / denominator


def R_m_parametric(m, q):
    """Boundary ratio of the bridge w^m + alpha w^(m+1), parametrised by q.

    With alpha = alpha_of_q(m, q), this returns h_max / sup_f where

        h_max = ((1+q)/2)^((m+1)/2) * ((1-q)/2)^((m-1)/2) * m q / ((m+1) q - 1)
        sup_f = 2^(-m) * (1 + alpha / 2).

    On the search interval (1/(m+1), 1/m) used by alpha_of_k, the value tends
    to ratio_cf(m + 1) as q -> 1/(m+1)+ and to ratio_cf(m) as q -> 1/m-.
    """
    plus_half = (1 + q) / 2
    minus_half = (1 - q) / 2
    amplitude = plus_half ** ((m + 1) / 2) * minus_half ** ((m - 1) / 2)
    peak = amplitude * (m * q / ((m + 1) * q - 1))
    boundary_sup = 0.5 ** m * (1 + alpha_of_q(m, q) / 2)
    return peak / boundary_sup


def alpha_of_k(k):
    """Closing coefficient alpha(k) of the refined bridge (Lemma 5(c)).

    For k == 1 the bare monomial w already closes, so 0.0 is returned.
    Otherwise m = m_of_k(k) and q is bisected (80 halvings) on
    (1/(m+1), 1/m), inset by 1e-9 at both ends to stay clear of the pole of
    alpha_of_q at q = 1/(m+1), solving R_m_parametric(m, q) = 2 sqrt(k).  The
    result is alpha_of_q(m, q*) at the midpoint of the final bracket.

    The bracket is not checked: it relies on R tending to ratio_cf(m + 1) and
    ratio_cf(m) at the two ends, with m_of_k placing 2 sqrt(k) between them.
    """
    if k == 1:
        return 0.0  # w^1 alone already gives ratio 2 = 2 sqrt(1)
    m = m_of_k(k)
    goal = 2.0 * math.sqrt(k)
    lo = 1.0 / (m + 1) + 1e-9
    hi = 1.0 / m - 1e-9
    excess_hi = R_m_parametric(m, hi) - goal
    for _ in range(80):
        mid = 0.5 * (lo + hi)
        excess_mid = R_m_parametric(m, mid) - goal
        if excess_mid * excess_hi < 0:
            # sign change between mid and hi: the root is in (mid, hi)
            lo = mid
        else:
            hi = mid
            excess_hi = excess_mid
    return alpha_of_q(m, 0.5 * (lo + hi))


# ---------- Truncated weighted Bergman space ----------

def bergman_basis(N):
    """Exponent pairs (a, b) with a + b <= N, ordered by a and then by b.

    The pairs label the orthonormal basis vectors e_{a,b} of the degree-<= N
    truncation of H^2_alpha(B^2); there are (N + 1)(N + 2) / 2 of them.  A
    negative N yields an empty list.
    """
    indices = []
    for a in range(N + 1):
        indices.extend((a, b) for b in range(N - a + 1))
    return indices


def bergman_inner_norm_sq(a, b, alpha):
    """Squared norm of the monomial z_0^a z_1^b in H^2_alpha(B^2).

    The weighted Bergman measure on B^2 (a 4-real-dimensional ball) is
    normalised to total mass 1:

        dA_alpha  =  (alpha + 1)(alpha + 2) / pi^2 * (1 - |z|^2)^alpha  dV(z),

    with dV the Lebesgue measure, so that ||1|| = 1.  (For alpha = 0 the
    constant is 2 / pi^2 = 1 / vol(B^2).)  With this normalisation

        || z_0^a z_1^b ||^2  =  a! b! * Gamma(alpha + 3) / Gamma(a + b + alpha + 3).

    Parameters
    ----------
    a, b : non-negative integer exponents.
    alpha : Bergman weight; the measure is finite only for alpha > -1.

    Returns
    -------
    float
        The squared norm, evaluated in log space via math.lgamma so that large
        exponents do not overflow the intermediate factorials.
    """
    # use lgamma for numerical stability
    log_norm_sq = (
        math.lgamma(a + 1)
        + math.lgamma(b + 1)
        + math.lgamma(alpha + 3)
        - math.lgamma(a + b + alpha + 3)
    )
    return math.exp(log_norm_sq)


def normalised_basis(basis, alpha):
    """Return 1 / ||z_0^a z_1^b|| for every (a, b) in `basis`, as a NumPy array.

    Entry i is the factor that turns the monomial of basis[i] into the
    orthonormal vector e_{basis[i]} = z_0^a z_1^b * (entry i).
    """
    factors = []
    for a, b in basis:
        factors.append(1.0 / math.sqrt(bergman_inner_norm_sq(a, b, alpha)))
    return np.array(factors)


# ---------- SU(2) generators in the monomial basis ----------

def J_matrices(basis, alpha):
    """Build J_+, J_-, J_z as dense matrices in the orthonormal Bergman basis.

    Parameters
    ----------
    basis : list of (a, b) exponent pairs (e.g. from bergman_basis(N)); column
        i of each matrix is the image of e_{basis[i]}.  Images whose exponent
        pair is not in `basis` are dropped.
    alpha : Bergman weight.  Accepted for interface compatibility only: the
        generators preserve the total degree a + b, and within one degree the
        norm ratios below do not depend on alpha.

    Returns
    -------
    (Jp, Jm, Jz) : complex arrays of shape (len(basis), len(basis)).

    Derivation.  In the unnormalised monomial basis
        J_+ (z_0^a z_1^b) = b z_0^(a+1) z_1^(b-1),
        J_- (z_0^a z_1^b) = a z_0^(a-1) z_1^(b+1),
        J_z (z_0^a z_1^b) = (a - b)/2 z_0^a z_1^b.
    With e_{a,b} = z_0^a z_1^b / N_{a,b} and
    N_{a,b}^2 = a! b! Gamma(alpha + 3) / Gamma(a + b + alpha + 3), the Gamma
    factors cancel at fixed degree, so N_{a+1,b-1} / N_{a,b} = sqrt((a + 1) / b) and
        J_+ e_{a,b} = sqrt(b (a + 1)) e_{a+1,b-1},
        J_- e_{a,b} = sqrt(a (b + 1)) e_{a-1,b+1},
    i.e. the standard spin-(a+b)/2 ladder elements.  Using this closed form
    avoids the round-off of evaluating the norms through exp/lgamma.
    """
    dim = len(basis)
    idx = {ab: i for i, ab in enumerate(basis)}
    Jp = np.zeros((dim, dim), dtype=complex)
    Jm = np.zeros((dim, dim), dtype=complex)
    Jz = np.zeros((dim, dim), dtype=complex)
    for (a, b), i in idx.items():
        # J_+ : e_{a,b} -> e_{a+1,b-1}
        if b >= 1:
            j = idx.get((a + 1, b - 1))
            if j is not None:
                Jp[j, i] += math.sqrt(b * (a + 1))
        # J_- : e_{a,b} -> e_{a-1,b+1}
        if a >= 1:
            j = idx.get((a - 1, b + 1))
            if j is not None:
                Jm[j, i] += math.sqrt(a * (b + 1))
        # J_z is diagonal in the monomial basis
        Jz[i, i] += (a - b) / 2.0
    return Jp, Jm, Jz


def T_w_pow(basis, alpha, n):
    """Truncated holomorphic Toeplitz operator of the symbol w^n = (z_0 z_1)^n.

    Multiplication by w^n sends z_0^a z_1^b to z_0^(a+n) z_1^(b+n), so in the
    orthonormal basis

        T e_{a,b} = (N_{a+n,b+n} / N_{a,b}) e_{a+n,b+n},   N^2 = bergman_inner_norm_sq.

    Images whose exponent pair is not in `basis` are dropped (compression onto
    the truncated space).  Returns a complex (len(basis), len(basis)) array
    whose column i is the image of e_{basis[i]}.
    """
    position = {pair: col for col, pair in enumerate(basis)}
    size = len(basis)
    M = np.zeros((size, size), dtype=complex)
    for (a, b), col in position.items():
        row = position.get((a + n, b + n))
        if row is None:
            continue  # w^n z_0^a z_1^b lies outside the truncation
        M[row, col] += math.sqrt(
            bergman_inner_norm_sq(a + n, b + n, alpha)
            / bergman_inner_norm_sq(a, b, alpha)
        )
    return M


def T_bridge(basis, alpha, k):
    """Truncated Toeplitz operator of the refined bridge f_k = w^m + alpha(k) w^(m+1).

    m = m_of_k(k) and alpha(k) = alpha_of_k(k).  The `alpha` argument is the
    Bergman weight passed on to T_w_pow, not the bridge coefficient.
    """
    degree = m_of_k(k)
    leading = T_w_pow(basis, alpha, degree)
    correction = T_w_pow(basis, alpha, degree + 1)
    return leading + alpha_of_k(k) * correction


def alpha_bridge_norm_op(T_b):
    """Operator (spectral) norm of the matrix T_b, i.e. its largest singular value, as a float."""
    spectral = la.norm(T_b, ord=2)
    return float(spectral)


def sigma_matrices():
    """Return the Pauli matrices (sigma_x, sigma_y, sigma_z) as complex 2x2 arrays."""
    pauli_x = np.array([[0, 1], [1, 0]], dtype=complex)
    pauli_y = np.array([[0, -1j], [1j, 0]], dtype=complex)
    pauli_z = np.diag(np.array([1, -1], dtype=complex))
    return pauli_x, pauli_y, pauli_z


def D_alpha(Jp, Jm, Jz):
    """Dirac operator sum_a J_a (X) sigma_a assembled from the ladder generators.

    J_1 = (J_+ + J_-)/2 and J_2 = (J_+ - J_-)/(2i) are recovered from the
    ladder operators, and J_3 = J_z.  The Pauli matrix is the right-hand
    Kronecker factor, so the result acts on (function space) (X) C^2.
    """
    sigma_1, sigma_2, sigma_3 = sigma_matrices()
    cartesian = ((Jp + Jm) / 2.0, (Jp - Jm) / (2.0j), Jz)
    result = np.kron(cartesian[0], sigma_1)
    result = result + np.kron(cartesian[1], sigma_2)
    result = result + np.kron(cartesian[2], sigma_3)
    return result


def lcomm_ratio(T_b, D_op, dim, I2):
    """Return || [D_op, T_b (X) I2] ||_op / || T_b ||_op using spectral norms.

    `dim` is not used by the computation; it is kept for call compatibility.
    A zero T_b makes the denominator 0.0 and raises ZeroDivisionError, which
    is why main() screens out vanishing bridges before calling this.
    """
    lifted = np.kron(T_b, I2)
    commutator = D_op @ lifted - lifted @ D_op
    return float(la.norm(commutator, ord=2)) / alpha_bridge_norm_op(T_b)


def _print_closing_parameters(ks):
    """Print the Lemma 5(c) table of (k, 2 sqrt(k), m(k), alpha(k)) for each k in ks."""
    print(f"  {'k':>2}  {'2 sqrt(k)':>10}  {'m(k)':>5}  {'alpha(k)':>10}")
    for k in ks:
        print(f"  {k:>2}  {2*math.sqrt(k):>10.6f}  {m_of_k(k):>5}  {alpha_of_k(k):>+10.4f}")
    print()


def _print_gap_table(D, alpha_Bergman, I2):
    """Print the signed relative L_comm gap for k = 1..k_D at four cutoffs N >= 2 k_D + 2.

    D enters only through k_D = floor(sqrt(D)).  Each cell is
    (R - 2 sqrt(k)) / (2 sqrt(k)) with R = lcomm_ratio(...), or "(T = 0)" when
    the truncated bridge vanishes (Frobenius norm below 1e-12).  Every cell is
    15 characters wide so it lines up under its "k=... gap (rel)" header.
    """
    k_D = int(math.floor(math.sqrt(D)))
    ks = range(1, k_D + 1)
    print("=" * 90)
    print(f"  D = {D},  k_D = floor(sqrt(D)) = {k_D}")
    print("=" * 90)
    # need N >= 2 k_D + 2 for the bridge to fit fully into the truncated space
    N_min = 2 * k_D + 2
    print(f"  {'N':>4}  {'dim H_N':>8}  " + "".join(f"k={k} gap (rel)  " for k in ks))
    for N in (N_min, N_min + 2, N_min + 4, N_min + 6):
        basis = bergman_basis(N)
        dim = len(basis)
        D_op = D_alpha(*J_matrices(basis, alpha_Bergman))
        cells = [f"  {N:>4}  {dim:>8}  "]
        for k in ks:
            T_b = T_bridge(basis, alpha_Bergman, k)
            if la.norm(T_b) < 1e-12:
                # padded to the 15-character width of a numeric cell
                cells.append("   (T = 0)     ")
                continue
            R = lcomm_ratio(T_b, D_op, dim, I2)
            target = 2.0 * math.sqrt(k)
            cells.append(f"   {(R - target) / target:+8.4f}    ")
        print("".join(cells))
    print()


def main():
    """Run Computation 59: print the closing parameters, the gap tables and the findings.

    Output only; nothing is returned.  The Bergman weight is fixed to 0 for
    every substrate size D = 4, 6, 8.
    """
    print("=" * 90)
    print("  Computation 59 -- Lemma 7: uniform-in-k closure of the refined bridge at finite N")
    print("=" * 90)
    print()
    print("  Lemma 5(c) (Comp 58) gives the exact closing alpha(k); this computation")
    print("  measures the relaxed L_comm gap at finite Bergman truncation N for the")
    print("  EXACT 2-monomial bridge and confirms the uniform-in-k decay as N grows.")
    print()
    print("  alpha_Bergman is set to 0 throughout (canonical Bergman weight on B^2);")
    print("  the alpha-dependence enters only the bridge norms, not the J_a structure.")
    print()

    alpha_Bergman = 0.0
    I2 = np.eye(2, dtype=complex)

    # Display table only; T_bridge recomputes alpha(k) itself for every N.
    _print_closing_parameters(range(1, 5))

    for D in (4, 6, 8):
        _print_gap_table(D, alpha_Bergman, I2)

    print("=" * 90)
    print("  Findings (Lemma 7)")
    print("=" * 90)
    print()
    print("  Reading the rows: at each substrate size D and each Bergman cutoff N,")
    print("  the relative gap (R - 2 sqrt(k)) / (2 sqrt(k)) is reported for")
    print("  k = 1, ..., k_D.  Values close to 0 = closure; positive = overshoot;")
    print("  negative = undershoot.  Lemma 7 conjecture: gap = O(1/N) uniformly in k.")
    print()
    print("  Expected: at N >= 2 k_D + 2 the bridge fits; the gap should be small")
    print("  and decreasing with N at every k.  Uniformity in k is the new claim;")
    print("  Comps 33-36 measured the same quantity for a DIFFERENT bridge (multiplicative")
    print("  Mosco-averaged) where gap = O(1) and never closed.  With the EXACT closing")
    print("  alpha(k) from Lemma 5(c) the gap should be much smaller and N-controlled.")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()
