# training/ignite_note.py

import os
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR
from pathlib import Path
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping, ModelCheckpoint
from ignite.metrics import Average
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.contrib.handlers.wandb_logger import WandBLogger
import wandb

from dl.models.classes import MultiClassOnsetClassifier
from training.class_loader import ClassBaseLoader
from utils import setup_checkpoint_upload


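# Ignite's Checkpoint handler keeps the highest-scoring checkpoints, so return
# the negated validation loss: a lower loss yields a higher score.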
def score_function(engine: Engine):
    val_loss = engine.state.metrics["loss"]
    return -val_loss


def ignite_train(
    train_dataset: ClassBaseLoader,
    valid_dataset: ClassBaseLoader,
    model: MultiClassOnsetClassifier,
    train_loader,
    valid_loader,
    optimizer: Optimizer,
    train_dataset_len,
    valid_dataset_len,
    device,
    lr_scheduler: CyclicLR | CosineAnnealingLR,
    wandb_logger: WandBLogger,
    **run_parameters,
):
    epochs = run_parameters.get("epochs", 100)
    epoch_length = run_parameters.get("epoch_length", 100)
    checkpoint_interval = run_parameters.get("checkpoint_interval", 100)
    validation_interval = run_parameters.get("validation_interval", 100)
    warmup_steps = run_parameters.get("warmup_steps", 0)
    wandb_mode = run_parameters.get("wandb_mode", "online")
    n_saved_model = run_parameters.get("n_saved_model", 10)
    n_saved_checkpoint = run_parameters.get("n_saved_checkpoint", 10)
    resume_checkpoint = run_parameters.get("resume_checkpoint", None)
    batch_size = run_parameters.get("batch_size", 32)
    target_lr = optimizer.param_groups[0]["lr"]

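    # Infinite batch generator: the loader yields files and the owning
    # dataset's process() expands each file into batches.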
    def cycle(dataloader, dataset):
        while True:
            for file in dataloader:
                for batch in dataset.process(file):
                    yield batch

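    # One optimization step. Linear LR warmup is applied for the first
    # warmup_steps iterations; the scheduler only steps once warmup is done.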
    def train_step(engine: Engine, batch):
        model.train()
        if warmup_steps > 0 and engine.state.iteration < warmup_steps:
            lr_scale = min(1.0, float(engine.state.iteration + 1) / float(warmup_steps))
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * target_lr

        optimizer.zero_grad()
        preds, losses = model.run_on_batch(batch)
        loss = losses["loss"]
        loss.backward()
        optimizer.step()
        if lr_scheduler and engine.state.iteration >= warmup_steps:
            lr_scheduler.step()

        return preds, {k: v.item() for k, v in losses.items()}

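    # Validation step: forward pass only, with gradients disabled.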
    def eval_step(engine: Engine, batch):
        model.eval()
        with torch.no_grad():
            preds, losses = model.run_on_batch(batch)
            return preds, {k: v.item() for k, v in losses.items()}

    trainer = Engine(train_step)
    evaluator = Engine(eval_step)

    # Attach metrics: use separate Average instances so the trainer and the
    # evaluator accumulate their running losses independently.
    train_avg_loss = Average(output_transform=lambda output: output[1]["loss"])
    valid_avg_loss = Average(output_transform=lambda output: output[1]["loss"])
    train_avg_loss.attach(trainer, "loss")
    valid_avg_loss.attach(evaluator, "loss")

    # Logging
    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def log_validation(engine: Engine):
        evaluator.run(
            cycle(valid_loader, valid_dataset),
            epoch_length=valid_dataset_len // batch_size,
        )

        metrics = evaluator.state.metrics
        epoch = engine.state.epoch

        print(f"[Validation] Epoch {epoch} - Loss: {metrics['loss']:.4f}")
        if wandb_mode != "disabled":
            for k, v in metrics.items():
                wandb_logger.log({f"validation/{k}": v, "epoch": epoch})

    # Checkpointing
    to_save = {"model": model, "optimizer": optimizer}
    if wandb_mode != "disabled":

        if resume_checkpoint:
            # resume_checkpoint is assumed to be the step suffix that DiskSaver
            # wrote into the checkpoint filename.
            checkpoint_path = os.path.join(
                wandb.run.dir, "checkpoints", f"checkpoint_{resume_checkpoint}.pth"  # type: ignore
            )
            Checkpoint.load_objects(
                to_load={"model": model, "optimizer": optimizer},
                checkpoint=torch.load(checkpoint_path, map_location=device),
            )

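        # Periodic checkpoint of model + optimizer state, kept as a rolling
        # window of the most recent n_saved_checkpoint files.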
        handler = Checkpoint(
            to_save,
            DiskSaver(os.path.join(wandb.run.dir, "checkpoints"), create_dir=True, require_empty=False),  # type: ignore
            n_saved=n_saved_checkpoint,
        )
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=checkpoint_interval),
            handler,
        )

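        # Best-so-far checkpoints ranked by validation loss; attached to the
        # evaluator so score_function receives the engine that holds the
        # validation metrics.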
        best_checkpoint = Checkpoint(
            to_save,
            DiskSaver(os.path.join(wandb.run.dir), create_dir=False, require_empty=False),  # type: ignore
            n_saved=2,
            score_function=score_function,
            score_name="validation_loss",
            greater_or_equal=True,
        )

        evaluator.add_event_handler(Events.COMPLETED, best_checkpoint)

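        # Best model weights only (no optimizer state), also ranked by
        # validation loss.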
        best_model = ModelCheckpoint(
            dirname=os.path.join(wandb.run.dir),  # type: ignore
            filename_prefix="model",
            n_saved=2,
            create_dir=True,
            require_empty=False,
            score_function=score_function,
            score_name="validation_loss",
            greater_or_equal=True,
        )
        evaluator.add_event_handler(
            Events.COMPLETED,
            best_model,
            {"mymodel": model},
        )

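        # Periodic model-only snapshots written to a separate directory.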
        model_handler = ModelCheckpoint(
            dirname=os.path.join(wandb.run.dir, "model_checkpoints"),  # type: ignore
            filename_prefix="model",
            n_saved=n_saved_model,
            create_dir=True,
            require_empty=False,
        )
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=checkpoint_interval),
            model_handler,
            {"mymodel": model},
        )

        # Log gradients and parameters; the model computes its loss inside
        # run_on_batch, so no separate criterion module is passed to watch().
        wandb.watch(model, log="all")

        setup_checkpoint_upload(trainer, {"model": model, "optimizer": optimizer}, wandb.run.dir, validation_interval=validation_interval)  # type: ignore

    # Progress bars
    ProgressBar(persist=True).attach(trainer, output_transform=lambda x: x[1]["loss"])
    ProgressBar(persist=True).attach(evaluator, output_transform=lambda x: x[1]["loss"])

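    # A single run() call drives the whole training; Ignite loops over epochs
    # internally.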
    trainer.run(
        cycle(train_loader, train_dataset),
        max_epochs=epochs,
        epoch_length=train_dataset_len // batch_size,
    )

    if wandb_mode != "disabled":
        wandb_logger.close()
