#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /root/CAMP/trainer.py
# Project: /home/richard/projects/DeepOrchestration/utils
# Created Date: Saturday, July 30th 2022, 4:03:31 pm
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Sat Jun 08 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2022 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2022 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

import os
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
import nni
from tqdm import tqdm
import torch


class Trainer(object):
    """Train and evaluate a model for a configured number of epochs.

    Two model families are supported, selected by ``cfg.model.name``:

    * ``"baseline"``: ``net(X)`` returns the prediction logits and the
      criterion is called as ``criterion(pred_logit, y)``.
    * ``"deepor"``: ``net(X)`` returns the 5-tuple ``(pred_logit,
      selected_features, reconstruct_repr, latent_ata, latent_dgfs)`` and the
      criterion is called as ``criterion(X, pred_logit, reconstruct_repr,
      latent_ata, latent_dgfs, y)``.

    In both cases the criterion must return a dict of scalar tensors holding
    at least ``"total_loss"`` (plus ``"prediction_loss"`` and
    ``"consistency_loss"`` for ``"deepor"``).
    """

    def __init__(self, net, criterion, dataloaders, optimizer, scheduler,
                 metrics, global_rank, cfg, logger, device):
        """Store the training components and reset all counters.

        Args:
            net: the model to train.
            criterion: callable returning a dict of loss tensors (see the
                class docstring for the call signature per model family).
            dataloaders: mapping from split name (e.g. ``"train"``,
                ``"valid"``) to an iterable yielding ``(X, y)`` batches.
            optimizer: optimizer updating the parameters of ``net``.
            scheduler: learning-rate scheduler; it is only stepped when
                ``cfg.train.lr_scheduler.when == "batch"`` and the type is
                ``"cosine"`` or ``"cycle"``.
            metrics: callable ``metrics(y_pred, y_true)`` taking two numpy
                arrays and returning a dict ``{metric_name: value}``.
            global_rank: rank of this process; only rank 0 logs.
            cfg: configuration object; the trainer reads
                ``cfg.train.num_epoch``, ``cfg.train.default_metric``,
                ``cfg.train.lr_scheduler.{when,type}``, ``cfg.model.name``,
                ``cfg.mode.nni`` and ``cfg.logger.log_per_steps``.
            logger: object exposing ``std_print(msg)`` and
                ``log_metric(name, value, step=...)``.
            device: device the batches are moved to.
        """
        self.net = net
        self.criterion = criterion
        self.dataloaders = dataloaders
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.metrics_func = metrics
        self.global_rank = global_rank
        self.cfg = cfg
        self.logger = logger
        self.device = device

        self.num_epoch = cfg.train.num_epoch
        self.default_metric = cfg.train.default_metric
        # Best validation value of ``default_metric`` seen so far; it is only
        # replaced by strictly larger values (see ``eval_epoch``).
        self.best_metric = 0

        self.epoch = 0
        self.global_train_step = 0
        self.global_valid_eval_epoch = 0
        self.global_train_eval_epoch = 0
        self.global_test_eval_epoch = 0

        self.best_model_path = Path(".")

        # Directory two levels above this file.
        self.root_level_dir = os.path.dirname(
            os.path.dirname(os.path.abspath(__file__)))

        # One entry per logged training step ("deepor" only): the selected
        # feature tensor converted to nested Python lists.
        self.feature_list = []

    def _forward(self, X, y):
        """Run the network on one batch and compute its loss dict.

        Args:
            X: input batch, already on ``self.device``.
            y: target batch, already on ``self.device``.

        Returns:
            tuple: ``(pred_logit, selected_features, loss)`` where
            ``selected_features`` is ``None`` for the ``"baseline"`` model and
            ``loss`` is the dict returned by the criterion.

        Raises:
            ValueError: if ``cfg.model.name`` is not a supported model name.
        """
        model_name = self.cfg.model.name

        if model_name == "baseline":
            pred_logit = self.net(X)
            return pred_logit, None, self.criterion(pred_logit, y)

        if model_name == "deepor":
            (pred_logit, selected_features, reconstruct_repr, latent_ata,
             latent_dgfs) = self.net(X)
            loss = self.criterion(X, pred_logit, reconstruct_repr, latent_ata,
                                  latent_dgfs, y)
            return pred_logit, selected_features, loss

        # Fail with a clear message instead of an UnboundLocalError later on.
        raise ValueError(
            "unsupported cfg.model.name: {!r}".format(model_name))

    def evaluate(self, split):
        """Evaluate the model on the given split.

        Args:
            split (str): key of ``self.dataloaders`` to evaluate on
                (e.g. ``"train"``, ``"valid"``, ``"test"``).

        Returns:
            dict: ``"{split}_loss"`` (mean of the per-batch total loss) plus
            one ``"{split}_{name}"`` entry per metric returned by the metrics
            function.
        """
        self.net.eval()

        y_true_list = []
        y_pred_list = []
        batch_losses = []

        with torch.no_grad():
            # Wrap the loader itself (not ``enumerate``) so tqdm can read its
            # length and display a total when one is available.
            for X, y in tqdm(self.dataloaders[split],
                             desc="evaluating {} ".format(split)):
                X = X.to(self.device)
                y = y.to(self.device)

                pred_logit, _, loss = self._forward(X, y)

                batch_losses.append(loss["total_loss"].item())
                y_true_list.extend(y.cpu().numpy().tolist())
                y_pred_list.extend(
                    pred_logit.squeeze(-1).cpu().numpy().tolist())

        metrics = self.metrics_func(np.array(y_pred_list),
                                    np.array(y_true_list))

        res_dict = {}
        if batch_losses:
            res_dict["{}_{}".format(split, "loss")] = np.mean(batch_losses)
        for name, value in metrics.items():
            res_dict["{}_{}".format(split, name)] = value

        return res_dict

    def eval_epoch(self, split):
        """Evaluate one split, log its metrics and track the best result.

        Every call increments the per-split counter
        ``self.global_{split}_eval_epoch``, which is used as the logging
        step. For the ``"valid"`` split the default metric is additionally
        reported to NNI (when ``cfg.mode.nni`` is set) and compared against
        ``self.best_metric``.

        Args:
            split (str): key of ``self.dataloaders`` to evaluate on.

        Returns:
            dict: the metrics returned by ``evaluate``; under NNI the
            ``"valid"`` result also carries a ``"default"`` entry.
        """
        if self.global_rank == 0:
            self.logger.std_print("-" * 10 + "evaluating {}".format(split) +
                                  "-" * 10)

        metrics = self.evaluate(split)

        # Resolve the counter by name so that splits other than
        # train/valid/test get a counter too instead of leaving ``step``
        # unbound.
        counter_name = "global_{}_eval_epoch".format(split)
        step = getattr(self, counter_name, 0) + 1
        setattr(self, counter_name, step)

        if not self.cfg.mode.nni and self.global_rank == 0:
            for metric_name, metric_v in metrics.items():
                self.logger.log_metric(metric_name, metric_v, step=step)

        if split != "valid":
            return metrics

        current_metric = metrics["{}_{}".format(split, self.default_metric)]

        if self.cfg.mode.nni:
            # NNI requires a "default" key when the result is a dict.
            metrics["default"] = current_metric
            nni.report_intermediate_result(metrics)

        # NOTE: stepping a "plateau" scheduler per epoch is currently
        # disabled:
        # if self.cfg.train.lr_scheduler.when == "epoch" and self.cfg.train.lr_scheduler.type in (
        #         "plateau", ):
        #     self.scheduler.step(metrics["{}_{}".format(
        #         split, self.default_metric)])

        if current_metric > self.best_metric:
            self.best_metric = current_metric

            # NOTE: saving the best model is currently disabled:
            # self.best_model_path = Path("model_step_{}_{}_{}".format(
            #     self.global_valid_eval_epoch, self.default_metric,
            #     round(metrics["{}_{}".format(split, self.default_metric)],
            #           3)))

            # if not self.cfg.mode.nni and self.global_rank == 0:
            #     save_path = self.best_model_path
            #     code_path = [os.path.join(self.cfg.orig_cwd, "models")]
            #     self.logger.save_model(self.net, save_path, code_path)

        return metrics

    def train_epoch(self, epoch):
        """Train the model for one epoch.

        Args:
            epoch (int): epoch number, used only for the progress message.

        Raises:
            ValueError: if ``cfg.model.name`` is not a supported model name.
        """
        if self.global_rank == 0:
            self.logger.std_print("-" * 10 +
                                  "training epoch {}".format(epoch) + "-" * 10)

        is_deepor = self.cfg.model.name == "deepor"
        lr_cfg = self.cfg.train.lr_scheduler
        step_scheduler_per_batch = (lr_cfg.when == "batch"
                                    and lr_cfg.type in ("cosine", "cycle"))
        logging_enabled = not self.cfg.mode.nni and self.global_rank == 0

        self.net.train()

        for X, y in self.dataloaders["train"]:
            X = X.to(self.device)
            y = y.to(self.device)

            _, selected_features, loss = self._forward(X, y)

            loss["total_loss"].backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            if is_deepor:
                # Feed the current error back to the sparsity controller.
                error = (loss["prediction_loss"].item() +
                         loss["consistency_loss"].item())
                self.net.dgfs.sparsity_controller.update(error)

            if step_scheduler_per_batch:
                self.scheduler.step()

            if (logging_enabled and self.global_train_step %
                    self.cfg.logger.log_per_steps == 0):
                # Read the learning rate from the optimizer directly: this
                # works without a scheduler and avoids rebuilding the whole
                # optimizer state dict on every logged step.
                cur_lr = self.optimizer.param_groups[0]['lr']

                for k, v in loss.items():
                    self.logger.log_metric("train/{}".format(k),
                                           v.item(),
                                           step=self.global_train_step)

                self.logger.log_metric("lr",
                                       float(cur_lr),
                                       step=self.global_train_step)

                if is_deepor:
                    # log the number of features
                    self.logger.log_metric("num_features",
                                           selected_features.sum().item(),
                                           step=self.global_train_step)

                    # log temperature
                    self.logger.log_metric(
                        "temperature",
                        self.net.dgfs.sparsity_controller.temperature,
                        step=self.global_train_step)

                    # record the selected features
                    self.feature_list.append(
                        selected_features.detach().cpu().numpy().tolist())

            self.global_train_step += 1

        torch.cuda.empty_cache()

    def run(self):
        """Run the training process.

        For every epoch the model is trained, then evaluated on the
        ``"train"`` and ``"valid"`` splits. Afterwards the best validation
        metric is reported to NNI (when enabled) and the recorded feature
        selections are written to ``feature_list.txt`` in the current
        working directory.
        """
        for epoch in range(1, self.num_epoch + 1):
            self.epoch = epoch
            self.train_epoch(self.epoch)
            self.eval_epoch("train")
            # ``eval_epoch("valid")`` already reports the intermediate result
            # to NNI; reporting it again here would submit two results per
            # epoch (and used a hard-coded "valid_auroc" key).
            self.eval_epoch("valid")

        if self.cfg.mode.nni:
            nni.report_final_result({"default": self.best_metric})

        # Only rank 0 ever appends to ``feature_list``; letting other ranks
        # write too could overwrite the file with an empty one.
        if self.global_rank == 0:
            with open("feature_list.txt", "w") as f:
                f.writelines("{}\n".format(item)
                             for item in self.feature_list)
