#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/DeepOrchestration/models/dgfs.py
# Project: /home/richard/projects/DeepOrchestration/models
# Created Date: Thursday, May 30th 2024, 12:02:20 am
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Sat Jun 08 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2024 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2024 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

import torch
import torch.nn as nn
import torch.nn.functional as F


def sample_gumbel(shape, eps=1e-20, device="cuda"):
    """Draw Gumbel(0, 1) noise of the given shape.

    The samples are obtained with the inverse-CDF transform
    ``-log(-log(U))`` applied to ``U`` drawn from ``torch.rand``.

    Args:
        shape: Shape of the noise tensor to generate.
        eps: Small constant added inside both logarithms so that neither of
            them is evaluated at exactly zero.
        device: Device on which the noise tensor is created.

    Returns:
        A tensor of shape ``shape`` on ``device`` holding the Gumbel samples.
    """
    # Allocate the uniform noise directly on the target device instead of
    # sampling on the CPU and copying it over on every call.
    uniform = torch.rand(shape, device=device)
    return -torch.log(-torch.log(uniform + eps) + eps)


def gumbel_softmax_sample(logits, temperature, device):
    """Return one relaxed (soft) Gumbel-Softmax sample for ``logits``.

    Gumbel noise with the same size as ``logits`` is added to the logits, the
    sum is divided by ``temperature`` and normalised with a softmax over the
    last dimension.

    Args:
        logits: Unnormalised scores; the softmax runs over the last dimension.
        temperature: Divisor applied to the perturbed logits before softmax.
        device: Device on which the Gumbel noise is generated.

    Returns:
        A tensor shaped like ``logits`` whose last dimension sums to one.
    """
    noise = sample_gumbel(logits.size(), device=device)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=-1)


def gumbel_softmax(logits, temperature, threshold=0.9, hard=False):
    """Straight-through Gumbel-Softmax with cumulative-probability selection.

    A relaxed sample is drawn with ``gumbel_softmax_sample``. In hard mode the
    classes of each row are ranked by sampled probability and the leading
    classes are switched on until their cumulative probability reaches
    ``threshold``, which yields a multi-hot (not one-hot) mask. The mask is
    combined with the soft sample so that the forward value is the hard mask
    while gradients flow through the soft sample.

    Args:
        logits: Unnormalised scores of shape ``[*, n_class]``.
        temperature: Temperature passed to the relaxed sampler.
        threshold: Cumulative probability that the selected classes must
            reach. At least one class per row is always selected.
        hard: When false, only the soft sample is produced.

    Returns:
        ``(y_soft, None)`` when ``hard`` is false, otherwise
        ``(sparse_y, y_soft)`` where ``sparse_y`` has the shape of ``logits``.
    """
    y_soft = gumbel_softmax_sample(logits, temperature, device=logits.device)

    if not hard:
        return y_soft, None

    sorted_probs, sorted_indices = y_soft.sort(dim=-1, descending=True)
    cum_probs = sorted_probs.cumsum(dim=-1)

    # Number of classes to keep per row: every class whose running total is
    # still below the threshold, plus the one that crosses it.
    k = (cum_probs < threshold).sum(dim=-1, keepdim=True) + 1

    # Build the selection as a rank mask so it works for any leading batch
    # dimensions and needs no host-side ``.item()`` call.
    ranks = torch.arange(logits.size(-1), device=logits.device)
    keep_sorted = (ranks < k).to(logits.dtype)

    # Move the mask from sorted order back to the original class order.
    y_hard = torch.zeros_like(logits).scatter_(-1, sorted_indices, keep_sorted)

    # Straight-through estimator: hard values forward, soft gradients back.
    sparse_y = y_hard - y_soft.detach() + y_soft
    return sparse_y, y_soft


import math


def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of a real scalar.

    The computation is split by sign so that the exponential is only ever
    evaluated at a non-positive argument and therefore cannot overflow.

    Args:
        x: Real number.

    Returns:
        The sigmoid of ``x`` as a float in ``[0, 1]``.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # For negative x, exp(-x) may exceed the float range and raise
    # OverflowError; exp(x) stays within (0, 1).
    exp_x = math.exp(x)
    return exp_x / (1 + exp_x)


class SparsityController:
    """PID-style controller that nudges a temperature value.

    Each call to ``update`` compares a measured value with ``setpoint`` and
    adds the weighted proportional, integral and derivative terms to
    ``temperature``, which is then kept inside ``[min_temp, max_temp]``.
    """

    def __init__(self, temperature, setpoint, Kp, Ki, Kd, min_temp, max_temp):
        # Controlled value and the range it is confined to.
        self.temperature = temperature
        self.min_temp, self.max_temp = min_temp, max_temp

        # Target value and the three controller gains.
        self.setpoint = setpoint
        self.Kp, self.Ki, self.Kd = Kp, Ki, Kd

        # Controller memory (previous error and accumulated error).
        self.reset()

    @staticmethod
    def _clamp(value, low, high):
        """Confine ``value`` to the closed interval ``[low, high]``."""
        return max(low, min(high, value))

    def update(self, current_loss):
        """Advance the controller by one step using ``current_loss``."""
        error = self.setpoint - current_loss

        # Accumulated error, limited to [-1, 1].
        self.integral = self._clamp(self.integral + error, -1, 1)

        # Change of the error since the previous step.
        error_delta = error - self.previous_error
        self.previous_error = error

        adjustment = (self.Kp * error + self.Ki * self.integral +
                      self.Kd * error_delta)

        # Apply the correction and keep the temperature inside its range.
        self.temperature = self._clamp(self.temperature + adjustment,
                                       self.min_temp, self.max_temp)

    def set_setpoint(self, setpoint):
        """Replace the target value used by ``update``."""
        self.setpoint = setpoint

    def reset(self):
        """Clear the stored previous error and the accumulated error."""
        self.previous_error = 0
        self.integral = 0


class DGFS(nn.Module):
    """Feature selection with a learnable mask followed by a latent projection.

    A learnable score per input feature (``feature_weights``) is turned into
    a feature mask. The masked input is passed through a small per-time-step
    MLP, flattened over time and features, and linearly projected to
    ``latent_dim``.

    Args:
        input_dim: Number of features in the last dimension of the input.
        time_steps: Number of time steps; the projection layer expects
            ``time_steps * input_dim`` flattened values per sample.
        latent_dim: Size of the returned latent representation.
        temperature: Initial temperature handed to the sparsity controller.
        setpoint: Target value of the sparsity controller.
        Kp: Proportional gain of the sparsity controller.
        Ki: Integral gain of the sparsity controller.
        Kd: Derivative gain of the sparsity controller.
        min_temp: Lower bound of the controller temperature.
        max_temp: Upper bound of the controller temperature.
    """

    def __init__(self,
                 input_dim,
                 time_steps,
                 latent_dim,
                 temperature=1.0,
                 setpoint=0,
                 Kp=0.001,
                 Ki=0.001,
                 Kd=0.01,
                 min_temp=0.4,
                 max_temp=1.0):
        super().__init__()
        self.input_dim = input_dim
        self.time_steps = time_steps
        self.latent_dim = latent_dim
        # One learnable selection score per input feature.
        self.feature_weights = nn.Parameter(torch.randn(input_dim))
        self.sparsity_controller = SparsityController(temperature=temperature,
                                                      setpoint=setpoint,
                                                      Kp=Kp,
                                                      Ki=Ki,
                                                      Kd=Kd,
                                                      min_temp=min_temp,
                                                      max_temp=max_temp)
        self.latent_projection = nn.Linear(self.time_steps * self.input_dim,
                                           self.latent_dim)

        self.mlp_layer = nn.Sequential(nn.Linear(self.input_dim,
                                                 64), nn.ReLU(),
                                       nn.Linear(64, self.input_dim))

    def _encode(self, x, feature_mask):
        """Mask the features of ``x`` and project the result to latent space."""
        batch_size, _, _ = x.size()

        # The mask has one entry per feature and broadcasts over the batch
        # and time dimensions.
        selected_features = x * feature_mask

        transformed_features = F.relu(self.mlp_layer(selected_features))

        # Flatten time and feature dimensions for the projection layer.
        flattened = transformed_features.reshape(batch_size, -1)
        return self.latent_projection(flattened)

    def forward(self, x):
        """Encode ``x`` using a mask sampled with hard Gumbel-Softmax.

        Args:
            x: Input tensor with three dimensions; the last one is the
                feature dimension.

        Returns:
            ``(latent_repr, sparse_weights)``: the projected representation
            and the sampled feature mask.
        """
        # The temperature is read from the controller on every call, so
        # controller updates take effect on the next forward pass.
        sparse_weights, _ = gumbel_softmax(
            self.feature_weights,
            self.sparsity_controller.temperature,
            hard=True)

        return self._encode(x, sparse_weights), sparse_weights

    def predict(self, x, n_features=10):
        """Encode ``x`` using the ``n_features`` highest-scoring features.

        Args:
            x: Input tensor with three dimensions; the last one is the
                feature dimension.
            n_features: Number of features to keep. Values larger than the
                number of available features select all of them.

        Returns:
            ``(latent_repr, sparse_weights)``: the projected representation
            and the 0/1 feature mask that was applied.
        """
        weights = self.feature_weights

        # torch.topk raises when k exceeds the number of entries, so cap it.
        k = min(n_features, weights.numel())
        topk_index = torch.topk(weights, k).indices

        sparse_weights = torch.zeros_like(weights)
        sparse_weights[topk_index] = 1

        return self._encode(x, sparse_weights), sparse_weights


# Testing the DGFS module
if __name__ == "__main__":
    # Smoke test: build a DGFS module, run one forward pass on random data
    # and print the resulting shapes and feature mask.

    # Define the input dimensions
    batch_size = 32
    input_dim = 20
    time_step = 48
    latent_dim = 128

    # Use the first GPU when one is present, otherwise run on the CPU so the
    # smoke test does not crash on CPU-only machines.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Initialize the DGFS module
    dgfs = DGFS(input_dim,
                time_step,
                latent_dim,
                temperature=1.0,
                setpoint=0,
                Kp=0.001,
                Ki=0.001,
                Kd=0.01,
                min_temp=1.0,
                max_temp=1.0)
    dgfs.to(device)

    # Generate a random input tensor
    x = torch.randn(batch_size, time_step, input_dim)  # Batch size of 32
    x = x.to(device)

    # Forward pass; forward() returns exactly two values.
    output, sparse_weights = dgfs(x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)

    print(sparse_weights)
