import numpy as np
from erm_code_soln import create_vandermonde, solve_linear_LS, solve_linear_LS_gd
from sklearn.metrics import confusion_matrix

# Fixed-point (non-scientific) float printing on wide lines, for readable output.
np.set_printoptions(linewidth=120, suppress=True)
# One seeded generator shared by the tests, so random data is reproducible run to run.
rng = np.random.default_rng(42)

# ------------------------------------------------------------------------------
# Desired sample outputs (approximate) for comparison
# ------------------------------------------------------------------------------
# Vandermonde Example 1:
#  [[1 1 1]
#   [1 3 9]
#   [1 2 4]]
#
# Vandermonde Example 2:
#  [[ 1 -2  4 -8]
#   [ 1 -1  1 -1]
#   [ 1  0  0  0]
#   [ 1  1  1  1]
#   [ 1  2  4  8]]
#
# solve_linear_LS Example 1:
#  [-4.    6.5  -1.5 ]
#
# solve_linear_LS Example 2:
#  [ 1.25714286  0.58333333  0.07142857 -0.08333333]
#
# Polynomial fit sanity (as in notebook):
#  z_hat approx [3.3846 1.9807 0.7115]
#  MSE approx 2.7308


def assert_close(a, b, tol=1e-6, label="value"):
    """Print and return whether ``a`` and ``b`` agree to an absolute tolerance.

    Despite the name this does NOT raise. It prints one summary line
    (``max |Δ|`` and ``OK``/``NOT OK``) and returns the verdict, so the caller
    decides what a failure means.

    Args:
        a, b: Array-likes (ndarrays, lists, or scalars) compared elementwise.
            Shapes must match, apart from singleton axes (e.g. ``(n,)`` vs
            ``(n, 1)``, or ``()`` vs ``(1,)``).
        tol: Largest allowed absolute elementwise difference.
        label: Name shown in the printed summary.

    Returns:
        ``True`` if the shapes are compatible and ``max |a - b| <= tol``.
        ``False`` otherwise: shape mismatch, any difference above ``tol``,
        or a NaN in the difference.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        # Broadcasting would silently compare the wrong elements, e.g. (1,)
        # against (n,), or (n, 1) against (n,) -> an n x n table. Only
        # singleton axes are forgiven.
        shape_a, shape_b = a.shape, b.shape
        a, b = np.squeeze(a), np.squeeze(b)
        if a.shape != b.shape:
            print(f"  {label}: shape mismatch {shape_a} vs {shape_b} NOT OK")
            return False
    # Subtract in at least float64 so unsigned/small integer inputs cannot wrap
    # around. Complex inputs stay complex and np.abs then gives the modulus.
    diff = np.subtract(a, b, dtype=np.result_type(a, b, np.float64))
    # np.max has no identity element, so empty inputs are handled explicitly.
    err = float(np.max(np.abs(diff))) if diff.size else 0.0
    ok = err <= tol  # False for NaN, which is the desired verdict
    print(f"  {label}: max |Δ|={err:.3e} {'OK' if ok else 'NOT OK'}")
    return ok


def poly_mse(A, y, z):
    """Return the mean squared residual of the linear model ``A @ z`` against ``y``.

    Nothing here is polynomial-specific: ``A`` can be any design matrix. In this
    script it is usually a Vandermonde matrix. The result is a plain Python float.
    """
    residual = y - A @ z
    return float(np.square(residual).mean())


def test_vandermonde_and_lstsq():
    """Check ``create_vandermonde`` and ``solve_linear_LS`` on small hand examples.

    Each result is printed and compared with the reference outputs listed in
    the comment block at the top of this file. The comparison goes through
    ``assert_close``, which only reports OK/NOT OK.
    """
    print("\n=== Test 1: Vandermonde + Closed-form LS ===")
    A1 = create_vandermonde(np.asarray([1, 3, 2]), 2)
    print("Vandermonde Example 1:\n", A1, "\n")
    # Expected columns are increasing powers x**0, x**1, ..., x**deg.
    A1_ref = np.array([[1, 1, 1],
                       [1, 3, 9],
                       [1, 2, 4]])
    assert_close(A1, A1_ref, tol=1e-12, label="A1")

    A2 = create_vandermonde(np.arange(-2, 3), 3)
    print("Vandermonde Example 2:\n", A2, "\n")
    A2_ref = np.array([[1, -2, 4, -8],
                       [1, -1, 1, -1],
                       [1, 0, 0, 0],
                       [1, 1, 1, 1],
                       [1, 2, 4, 8]])
    assert_close(A2, A2_ref, tol=1e-12, label="A2")

    # A1 is square with distinct nodes (nonsingular), so LS is exact interpolation.
    z1 = solve_linear_LS(A1, np.asarray([1, 2, 3]))
    print("solve_linear_LS Example 1:\n", z1, "\n")
    assert_close(z1, np.array([-4.0, 6.5, -1.5]), tol=1e-8, label="z1")

    # A2 is tall (5 x 4), so this is a genuine least-squares fit.
    z2 = solve_linear_LS(A2, np.asarray([1, 1, 1, 2, 2]))
    print("solve_linear_LS Example 2:\n", z2, "\n")
    z2_ref = np.array([1.25714286, 0.58333333, 0.07142857, -0.08333333])
    assert_close(z2, z2_ref, tol=1e-6, label="z2")


def test_gradient_descent_vs_closed_form():
    """Check that gradient-descent LS agrees with the closed-form LS solution.

    Builds a random tall system with small additive noise and solves it both
    ways. It then reports (via ``assert_close``, which prints OK/NOT OK and
    does not raise) whether the coefficients and the training MSEs agree.
    Draws from the shared module-level ``rng``.
    """
    print("\n=== Test 2: Gradient Descent vs Closed-form LS ===")
    # Random tall system
    m, n = 80, 8
    A = rng.normal(size=(m, n))
    z_true = rng.normal(size=(n,))
    y = A @ z_true + 0.01 * rng.normal(size=(m,))

    z_ls = solve_linear_LS(A, y)
    z_gd = solve_linear_LS_gd(A, y, step=5e-3, niter=30_000)

    print("z_LS (first 5):", z_ls[:5])
    print("z_GD (first 5):", z_gd[:5], "\n")

    mse_ls = poly_mse(A, y, z_ls)
    mse_gd = poly_mse(A, y, z_gd)
    print(f"MSE LS: {mse_ls:.6e}, MSE GD: {mse_gd:.6e}")
    assert_close(z_gd, z_ls, tol=4e-3, label="z_GD vs z_LS (coeff)")
    # With noise std 0.01 the MSE is ~1e-4. An absolute tolerance of 4e-3 on
    # the MSE would therefore accept a GD solution ~40x worse than LS. Compare
    # the ratio instead: LS is the minimizer, so MSE(GD)/MSE(LS) >= 1 and it
    # should sit within 1% of 1. tiny only guards a (noise-free) zero MSE.
    mse_ratio = mse_gd / max(mse_ls, np.finfo(float).tiny)
    assert_close(np.array([mse_ratio]), np.array([1.0]), tol=1e-2, label="MSE(GD)/MSE(LS)")


def test_notebook_poly_fit():
    """Reproduce the notebook's quadratic least-squares fit, then repeat it with GD.

    The reference coefficients and MSE correspond to the data exactly as
    written below. Note that x = 1 appears twice (with y = 3 and y = 7).
    Results are reported OK/NOT OK by ``assert_close``, which does not raise.
    """
    print("\n=== Test 3: Notebook Polynomial Fit (sanity) ===")
    # Notebook points: (-2,2), (1,3), (0,5), (1,7), (2,11); quadratic model.
    xs = np.array([-2, 1, 0, 1, 2])
    ys = np.array([2, 3, 5, 7, 11])
    degree = 2
    design = create_vandermonde(xs, degree)

    # --- closed-form least squares ------------------------------------------
    coef_ls = solve_linear_LS(design, ys)
    err_ls = poly_mse(design, ys, coef_ls)
    print("z_hat:", coef_ls)
    print(f"MSE: {err_ls:.4f}\n")

    # Reference values are rounded to 4 decimals, hence the 2e-4 tolerance.
    expected_coef = np.array([3.3846, 1.9807, 0.7115])
    expected_mse = 2.7308
    assert_close(coef_ls, expected_coef, tol=2e-4, label="notebook z_hat")
    assert_close(np.array([err_ls]), np.array([expected_mse]), tol=2e-4, label="notebook MSE")

    # --- gradient descent on the same system should land close to LS --------
    coef_gd = solve_linear_LS_gd(design, ys, step=1e-3, niter=30_000)
    err_gd = poly_mse(design, ys, coef_gd)
    print("z_hat (GD):", coef_gd)
    print(f"MSE (GD): {err_gd:.4f}")
    assert_close(coef_gd, coef_ls, tol=4e-3, label="notebook GD ~ LS (coeff)")
    assert_close(np.array([err_gd]), np.array([err_ls]), tol=4e-3, label="notebook GD ~ LS (MSE)")


def test_regularization_stability():
    """Compare regularized and unregularized LS on an ill-conditioned polynomial fit.

    Fits a degree-12 polynomial to noisy samples of a quadratic, with
    ``reg=0.0`` and ``reg=1e-2``. Prints both coefficient norms and training
    MSEs, then reports OK/NOT OK on whether the regularized coefficients have
    the smaller norm. The verdict is printed only; nothing is raised.
    Draws from the shared module-level ``rng``.
    """
    print("\n=== Test 4:  Regularization on Ill-conditioned Vandermonde ===")
    # A notoriously ill-conditioned Vandermonde (high monomial degree on [-1, 1])
    x = np.linspace(-1, 1, 41)
    deg = 12
    A = create_vandermonde(x, deg)
    # Fit noisy quadratic ground-truth with a far higher-degree model
    y = 0.3 + 0.5 * x + 0.2 * x**2 + 0.02 * rng.normal(size=x.shape)

    z_unreg = solve_linear_LS(A, y, reg=0.0)
    z_reg = solve_linear_LS(A, y, reg=1e-2)

    norm_unreg = np.linalg.norm(z_unreg)
    norm_reg = np.linalg.norm(z_reg)
    mse_unreg = poly_mse(A, y, z_unreg)
    mse_reg = poly_mse(A, y, z_reg)

    print(f"||z_unreg||={norm_unreg:.3e}, MSE_unreg={mse_unreg:.3e}")
    print(f"||z_reg  ||={norm_reg:.3e}, MSE_reg  ={mse_reg:.3e}")
    print("  Expect regularized norm to be noticeably smaller (more stable).")
    # Check the expectation rather than only stating it. Assumption: ``reg`` is
    # an L2 (ridge) penalty, as the script's entry point calls it. Then for
    # reg > 0 the penalized solution's norm is strictly smaller than the
    # unpenalized one's (up to round-off).
    ok = norm_reg < norm_unreg
    print(f"  ||z_reg|| < ||z_unreg||: {'OK' if ok else 'NOT OK'}")


if __name__ == "__main__":
    # Order matters: tests 2 and 4 draw from the shared module-level ``rng``,
    # so the sequence fixes which random data each of them sees.
    for run_test in (
        test_vandermonde_and_lstsq,            # 1) Vandermonde + closed-form LS sanity
        test_gradient_descent_vs_closed_form,  # 2) GD consistency with LS
        test_notebook_poly_fit,                # 3) notebook polynomial fit (sanity)
        test_regularization_stability,         # 4) ridge shrinks coefficient norm
    ):
        run_test()

    # "Completed" means every check ran. Individual checks print OK/NOT OK
    # themselves, so scan the output for failures.
    print("\nAll tests completed.")
