# python 3.12
# pip install torch devinterp matplotlib scikit-learn

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from devinterp.slt.callback import SamplerCallback
from devinterp.slt.llc import LLCEstimator
from devinterp.slt.sampler import sample
from devinterp.optim import SGLD
from sklearn.decomposition import PCA
from matplotlib.animation import FuncAnimation, FFMpegWriter
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import numpy as np
import os
import warnings
from tqdm import tqdm

# Silence all warnings process-wide (including third-party deprecation noise).
warnings.filterwarnings("ignore")

# ==========================================
# CONFIGURATION
# ==========================================
P: int = 67                     # modulus; also the embedding vocabulary size and number of output classes
DEVICE: str = "cpu"             # torch device the models and evaluation tensors are moved to
TRAIN_PCT: float = 0.45         # fraction of all P*P (a, b) pairs used for training
BATCH_SIZE: int = 512
ESTIMATION_INTERVAL: int = 100  # run evaluation + SLT estimation every this many epochs
EPOCHS: int = 10000

# SGLD settings for LLC estimation. Running several chains exposes mixing problems and
# yields an error bar; the author's guideline is to keep burn-in >= the number of draws.
SGLD_LR: float = 1e-4
SGLD_NUM_CHAINS: int = 4
SGLD_NUM_BURNIN: int = 200
SGLD_NUM_DRAWS: int = 200
SGLD_SEED: int = 42

# ==========================================
# MODELS
# ==========================================

class MLPConcat(nn.Module):
    """Concatenate the two operand embeddings and classify with a one-hidden-layer MLP.

    Args:
        p: number of tokens in the embedding table and number of output logits.
        dim: embedding width; the hidden layer has the same width.
    """

    def __init__(self, p, dim=128):
        super().__init__()
        self.embed = nn.Embedding(p, dim)
        hidden = nn.Linear(2 * dim, dim)
        readout = nn.Linear(dim, p)
        self.linear = nn.Sequential(hidden, nn.ReLU(), readout)

    def forward(self, x):
        # Assumes x is (batch, 2) token ids; the two embeddings are flattened side by side.
        batch = x.shape[0]
        pair_embedding = self.embed(x).view(batch, -1)
        return self.linear(pair_embedding)

class MLPAdd(nn.Module):
    """Sum the two operand embeddings and classify with a one-hidden-layer MLP.

    Because the embeddings are added, swapping the two operands yields identical logits.

    Args:
        p: number of tokens in the embedding table and number of output logits.
        dim: embedding width; the hidden layer has the same width.
    """

    def __init__(self, p, dim=128):
        super().__init__()
        self.p = p  # kept as an attribute; not read within this class
        self.embed = nn.Embedding(p, dim)
        layers = [nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, p)]
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        # Operands are read from columns 0 and 1 of x.
        lhs, rhs = x[:, 0], x[:, 1]
        summed = self.embed(lhs) + self.embed(rhs)
        return self.linear(summed)

# ==========================================
# CUSTOM ESTIMATORS
# ==========================================

class WBICEstimator(SamplerCallback):
    """Record the loss passed in at every sampler draw and report its average.

    The result is exposed under the key ``"wbic"``. It is the plain mean of the
    per-draw ``loss`` values (across every call, i.e. all chains when one instance
    is shared); it is not rescaled by the dataset size.
    """

    def __init__(self):
        super().__init__()
        self.losses = []

    def __call__(self, loss, **_kwargs):
        # Only `loss` (a scalar tensor) is used; other sampler kwargs are ignored.
        scalar = loss.item()
        self.losses.append(scalar)

    def get_results(self):
        """Return ``{"wbic": mean_loss}``, or ``{"wbic": 0.0}`` if nothing was recorded."""
        if not self.losses:
            return {"wbic": 0.0}
        return {"wbic": np.mean(self.losses)}

class FunctionalVarianceEstimator(SamplerCallback):
    """Measure how much the model's outputs vary across sampler draws.

    At every draw the sampled model is run on a fixed probe batch. The result is the
    per-entry variance of those outputs over all recorded draws (across every call,
    i.e. pooled over chains when one instance is shared), averaged to a scalar.

    Args:
        x_fixed: probe inputs evaluated at every draw. They are not moved to any
            device here, so they must already be on the model's device.
    """

    def __init__(self, x_fixed):
        super().__init__()
        self.x_fixed = x_fixed
        self.outputs = []

    def __call__(self, model, **_kwargs):
        # Evaluate in eval mode, then restore whatever mode the model was in so the
        # remainder of the sampling step is not silently switched to eval().
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                self.outputs.append(model(self.x_fixed).cpu().numpy())
        finally:
            model.train(was_training)

    def get_results(self):
        """Return ``{"func_var": float}``; 0.0 when no draws were recorded (as WBICEstimator does)."""
        if not self.outputs:
            # np.stack raises on an empty list; report 0.0 instead of crashing.
            return {"func_var": 0.0}
        stack = np.stack(self.outputs)  # (num_recorded_draws, *output_shape)
        return {"func_var": np.var(stack, axis=0).mean()}

def get_embedding_pca(model, n_components=6):
    """Project the rows of ``model.embed.weight`` onto their leading principal components.

    Args:
        model: any module exposing an ``embed`` embedding with a 2-D ``weight``.
        n_components: number of output columns (default 6, which the plotting code indexes).

    Returns:
        Array of shape ``(num_embeddings, n_components)``. If the weight matrix supports
        fewer components than requested, the missing trailing columns are zero.
    """
    weights = model.embed.weight.detach().cpu().numpy()
    # PCA cannot extract more than min(n_samples, n_features) components; clamping on
    # both axes avoids a ValueError when P (the row count) is small.
    k = min(n_components, *weights.shape)
    projected = PCA(n_components=k).fit_transform(weights)
    if k < n_components:
        # Zero-pad so callers can always index columns 0..n_components-1.
        padding = np.zeros((projected.shape[0], n_components - k), dtype=projected.dtype)
        projected = np.hstack([projected, padding])
    return projected

# ==========================================
# RUNNER
# ==========================================

def _estimate_slt_metrics(model, train_loader, probe_x, criterion, nbeta):
    """Run SGLD around the model's current weights and return (llc, llc_std, wbic, func_var).

    Args:
        model: network whose current parameters are the sampling centre.
        train_loader: loader the sampler draws minibatches from.
        probe_x: fixed inputs for the functional-variance callback.
        criterion: loss used by the sampler's ``evaluate`` function.
        nbeta: value shared by the LLC estimator and the SGLD optimizer.
    """
    # init_loss=0.0 is a placeholder; sample() overrides it via setattr after
    # computing the loss at the current model parameters.
    llc_est = LLCEstimator(num_chains=SGLD_NUM_CHAINS, num_draws=SGLD_NUM_DRAWS, nbeta=nbeta, device=DEVICE, init_loss=0.0)
    wbic_est = WBICEstimator()
    fvar_est = FunctionalVarianceEstimator(probe_x)

    # With a fixed `seed`, devinterp's sample() reseeds torch's global RNG per chain
    # (verify against the installed version). Without forking, every estimation would
    # leave the global RNG in the same state, so the training DataLoader would replay
    # an identical shuffle sequence after each estimation.
    with torch.random.fork_rng():
        sample(model, train_loader,
               num_chains=SGLD_NUM_CHAINS, num_draws=SGLD_NUM_DRAWS, num_burnin_steps=SGLD_NUM_BURNIN,
               evaluate=lambda m, d: criterion(m(d[0].to(DEVICE)), d[1].to(DEVICE)),
               callbacks=[llc_est, wbic_est, fvar_est],
               optimizer_kwargs=dict(lr=SGLD_LR, nbeta=nbeta),
               sampling_method=SGLD,
               # Keep sampling on the same device evaluate() moves batches to and
               # LLCEstimator was configured for.
               device=DEVICE,
               seed=SGLD_SEED, verbose=False)

    results = llc_est.get_results()
    return (
        results["llc/mean"],
        results.get("llc/std", 0.0),
        wbic_est.get_results()["wbic"],
        fvar_est.get_results()["func_var"],
    )


def _render_animation(history, model_name, output_path):
    """Plot the recorded metrics with an epoch slider and save a sweep over it as an MP4.

    Args:
        history: dict of per-estimation lists filled by run_experiment (must be non-empty).
        model_name: title shown at the top of the figure.
        output_path: destination video file, written with FFMpegWriter.
    """
    fig = plt.figure(figsize=(16, 12))
    fig.suptitle(f"Analysis: {model_name}", fontsize=16)

    ax_pca1 = fig.add_axes([0.05, 0.65, 0.25, 0.25])
    ax_pca2 = fig.add_axes([0.37, 0.65, 0.25, 0.25])
    ax_pca3 = fig.add_axes([0.70, 0.65, 0.25, 0.25])

    ax_acc  = fig.add_axes([0.1, 0.42, 0.65, 0.15])
    ax_slt  = fig.add_axes([0.1, 0.24, 0.65, 0.15])
    ax_fvar = fig.add_axes([0.1, 0.08, 0.65, 0.12])

    initial_pca = history["embeddings"][0]
    colors = range(P)

    scat1 = ax_pca1.scatter(initial_pca[:, 0], initial_pca[:, 1], c=colors, cmap='Greys', edgecolors='k', s=40)
    ax_pca1.set_title("PCA 1 & 2")
    scat2 = ax_pca2.scatter(initial_pca[:, 2], initial_pca[:, 3], c=colors, cmap='Greys', edgecolors='k', s=40)
    ax_pca2.set_title("PCA 3 & 4")
    scat3 = ax_pca3.scatter(initial_pca[:, 4], initial_pca[:, 5], c=colors, cmap='Greys', edgecolors='k', s=40)
    ax_pca3.set_title("PCA 5 & 6")

    ax_acc.plot(history["epochs"], history["train_acc"], label="Train Acc", alpha=0.3, color='blue')
    ax_acc.plot(history["epochs"], history["test_acc"], label="Test Acc", color='green', linewidth=2)
    ax_acc.set_ylabel("Accuracy")
    ax_acc.legend(loc='lower right')

    epochs_arr = np.array(history["epochs"])
    llc_arr = np.array(history["llc"])
    std_arr = np.array(history["llc_std"])
    ax_slt.plot(epochs_arr, llc_arr, color='red', label="LLC (λ)")
    ax_slt.fill_between(epochs_arr, llc_arr - std_arr, llc_arr + std_arr, color='red', alpha=0.2)
    ax_slt.set_ylabel("LLC")
    ax_wbic = ax_slt.twinx()
    ax_wbic.plot(history["epochs"], history["wbic"], color='purple', linestyle='--', label="WBIC")
    ax_wbic.set_ylabel("WBIC")
    ax_slt.legend(loc='upper left')
    ax_wbic.legend(loc='upper right')

    ax_fvar.plot(history["epochs"], history["func_var"], color='orange', label="Func Var")
    ax_fvar.set_ylabel("Var[f]")
    ax_fvar.set_xlabel("Epochs")

    vline1 = ax_acc.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
    vline2 = ax_slt.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
    vline3 = ax_fvar.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)

    txt_acc = fig.text(0.85, 0.495, '', verticalalignment='center')
    txt_slt = fig.text(0.85, 0.315, '', verticalalignment='center')
    txt_fvar = fig.text(0.85, 0.14, '', verticalalignment='center')

    # Slider value is an index into history; with estimates taken at multiples of
    # ESTIMATION_INTERVAL starting from 0, index == epoch / ESTIMATION_INTERVAL.
    ax_slider = fig.add_axes([0.2, 0.01, 0.6, 0.02])
    slider = Slider(ax_slider, f"Epoch/{ESTIMATION_INTERVAL}", 0, len(history["epochs"]) - 1, valinit=len(history["epochs"])-1, valstep=1)

    def update(_val):
        idx = int(slider.val)
        epoch = history["epochs"][idx]
        pca_data = history["embeddings"][idx]

        scat1.set_offsets(pca_data[:, 0:2])
        ax_pca1.set_xlim(pca_data[:, 0].min() - 0.1, pca_data[:, 0].max() + 0.1)
        ax_pca1.set_ylim(pca_data[:, 1].min() - 0.1, pca_data[:, 1].max() + 0.1)

        scat2.set_offsets(pca_data[:, 2:4])
        ax_pca2.set_xlim(pca_data[:, 2].min() - 0.1, pca_data[:, 2].max() + 0.1)
        ax_pca2.set_ylim(pca_data[:, 3].min() - 0.1, pca_data[:, 3].max() + 0.1)

        scat3.set_offsets(pca_data[:, 4:6])
        ax_pca3.set_xlim(pca_data[:, 4].min() - 0.1, pca_data[:, 4].max() + 0.1)
        ax_pca3.set_ylim(pca_data[:, 5].min() - 0.1, pca_data[:, 5].max() + 0.1)

        vline1.set_xdata([epoch])
        vline2.set_xdata([epoch])
        vline3.set_xdata([epoch])

        txt_acc.set_text(f"Train Acc: {history['train_acc'][idx]:.4f}\nTest Acc:  {history['test_acc'][idx]:.4f}")
        txt_slt.set_text(f"LLC:  {history['llc'][idx]:.4f} ± {history['llc_std'][idx]:.2f}\nWBIC: {history['wbic'][idx]:.4f}")
        txt_fvar.set_text(f"Func Var: {history['func_var'][idx]:.4f}")

        fig.canvas.draw_idle()

    slider.on_changed(update)
    update(len(history["epochs"]) - 1)

    print(f"Saving animation to {output_path}...")
    num_frames = len(history["epochs"])
    duration_seconds = 10
    calculated_fps = max(1, num_frames // duration_seconds)

    def animate(i):
        # Drive the slider so every frame reuses the interactive update() path.
        slider.set_val(i)
        return []

    ani = FuncAnimation(fig, animate, frames=num_frames, interval=1000/calculated_fps)
    writer = FFMpegWriter(fps=calculated_fps, bitrate=1800)
    ani.save(output_path, writer=writer)
    print(f"Saved {output_path} at {calculated_fps} FPS")
    plt.close(fig)


def run_experiment(use_symmetry: bool, multiply: bool, output_path: str):
    """Train one modular-arithmetic model, track SLT diagnostics, and save an animation.

    Every ESTIMATION_INTERVAL epochs (starting at epoch 0) the model's train/test
    accuracy is recorded. SGLD is then run around the current weights to estimate the
    LLC, a WBIC-style mean loss, and functional variance, and an embedding PCA snapshot
    is stored.

    Args:
        use_symmetry: use MLPAdd (operand embeddings summed) instead of MLPConcat.
        multiply: learn (a * b) mod P instead of (a + b) mod P.
        output_path: where the MP4 animation is written.

    Raises:
        ValueError: if no estimate was recorded (EPOCHS < 1), leaving nothing to plot.
    """
    ModelClass = MLPAdd if use_symmetry else MLPConcat
    model_name = ("Multiply " if multiply else "") + ("MLP-Add (Symmetric)" if use_symmetry else "MLP-Concat (Standard)")
    print(f"\n=== Training {model_name} -> {output_path} ===")

    model = ModelClass(P).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.4)
    criterion = nn.CrossEntropyLoss()

    # Data: every (a, b) pair in [0, P)^2, split once with a fixed-seed generator.
    pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))
    labels = (pairs[:, 0] * pairs[:, 1]) % P if multiply else (pairs[:, 0] + pairs[:, 1]) % P
    dataset = TensorDataset(pairs, labels)
    train_size = int(TRAIN_PCT * len(dataset))
    gen = torch.Generator().manual_seed(42)
    train_ds, test_ds = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size], generator=gen)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_x, test_y = dataset.tensors[0][test_ds.indices].to(DEVICE), dataset.tensors[1][test_ds.indices].to(DEVICE)

    # Loop-invariant: depends only on the training-set size.
    nbeta = train_size / np.log(train_size)

    history = {"train_acc": [], "test_acc": [], "llc": [], "llc_std": [], "wbic": [], "func_var": [], "embeddings": [], "epochs": []}

    for epoch in tqdm(range(EPOCHS)):
        model.train()
        correct = 0
        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Train accuracy uses this batch's pre-step outputs.
            correct += (out.argmax(1) == y).sum().item()

        if epoch % ESTIMATION_INTERVAL != 0:
            continue

        model.eval()
        train_acc = correct / train_size
        with torch.no_grad():  # evaluation only; no autograd graph needed
            test_acc = (model(test_x).argmax(1) == test_y).sum().item() / len(test_ds)

        llc, llc_std, wbic, fvar = _estimate_slt_metrics(model, train_loader, test_x[:100], criterion, nbeta)

        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)
        history["llc"].append(llc)
        history["llc_std"].append(llc_std)
        history["wbic"].append(wbic)
        history["func_var"].append(fvar)
        history["embeddings"].append(get_embedding_pca(model))
        history["epochs"].append(epoch)

    if not history["epochs"]:
        raise ValueError(f"No estimates recorded (EPOCHS={EPOCHS}); nothing to animate.")

    _render_animation(history, model_name, output_path)


if __name__ == "__main__":
    os.makedirs("generated", exist_ok=True)
    # One run per (use_symmetry, multiply, output_path) combination, in this order.
    experiment_grid = (
        (False, False, "generated/modulo_add.mp4"),
        (True, False, "generated/modulo_add_sym.mp4"),
        (False, True, "generated/modulo_mul.mp4"),
        (True, True, "generated/modulo_mul_sym.mp4"),
    )
    for symmetric, mul, path in experiment_grid:
        run_experiment(use_symmetry=symmetric, multiply=mul, output_path=path)
