#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/DeepOrchestration/models/ata.py
# Project: /home/richard/projects/DeepOrchestration/models
# Created Date: Thursday, May 30th 2024, 9:21:37 am
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Sat Jun 08 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2024 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2024 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerDecoder, TransformerDecoderLayer


class CustomMultiheadAttention(nn.Module):
    """Single-head scaled dot-product attention with an optional multiplicative prior.

    Query, key and value are each projected by a bias-free square weight
    matrix. The attended result then goes through ``out_proj``, which has a
    bias unless ``bias=False``.

    Args:
        embed_dim: feature size of the inputs and of every projection.
        dropout: dropout probability applied to the attention weights.
        bias: whether ``out_proj`` carries a bias term.
    """

    def __init__(self, embed_dim, dropout=0.0, bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout

        def square_weight():
            return nn.Parameter(torch.empty(embed_dim, embed_dim))

        # Registered in q, k, v order; the values are filled by _reset_parameters.
        self.q_proj_weight = square_weight()
        self.k_proj_weight = square_weight()
        self.v_proj_weight = square_weight()

        # out_proj.weight keeps nn.Linear's default initialisation.
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the q/k/v projections and zero the output bias."""
        for weight in (self.q_proj_weight, self.k_proj_weight,
                       self.v_proj_weight):
            nn.init.xavier_uniform_(weight)
        out_bias = self.out_proj.bias
        if out_bias is not None:
            nn.init.zeros_(out_bias)

    def forward(self, query, key, value, attention_prior=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query, key, value: tensors whose last dimension is ``embed_dim``.
                The matmuls treat the second-to-last dimension as the
                sequence axis.
            attention_prior: optional tensor, multiplied element-wise into the
                projected q, k and v after two leading singleton dims are
                added. NOTE(review): presumably a per-feature weighting of
                shape (embed_dim,); confirm against the caller.

        Returns:
            ``(output, weights)``, where ``weights`` are the post-dropout
            attention probabilities.
        """
        pairs = ((query, self.q_proj_weight), (key, self.k_proj_weight),
                 (value, self.v_proj_weight))
        q, k, v = (F.linear(tensor, weight) for tensor, weight in pairs)

        if attention_prior is not None:
            prior = attention_prior[None, None]
            q, k, v = q * prior, k * prior, v * prior

        context, weights = self.scaled_dot_product_attention(q, k, v)
        projected = F.linear(context, self.out_proj.weight, self.out_proj.bias)
        return projected, weights

    def scaled_dot_product_attention(self, q, k, v):
        """Return ``(softmax(q @ k^T / sqrt(d_k)) @ v, weights)``.

        Dropout is applied to the weights before they multiply ``v``, and the
        returned weights are the post-dropout ones.
        """
        scale = torch.sqrt(torch.tensor(q.size(-1), dtype=torch.float32))
        logits = q @ k.transpose(-2, -1) / scale
        weights = self.attn_dropout(logits.softmax(dim=-1))
        return weights @ v, weights


class CustomTransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer built on CustomMultiheadAttention.

    Self-attention and a two-layer ReLU feed-forward network are each
    followed by dropout, a residual add, and LayerNorm.

    Args:
        d_model: feature size of the input and output.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used inside attention, on both residual
            branches, and on the feed-forward hidden layer.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Keep this creation order: nn.Linear initialisation draws from the
        # global RNG, so reordering would change seeded initial weights.
        self.self_attn = CustomMultiheadAttention(d_model, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)
        self.activation = F.relu

    def forward(self,
                src,
                src_mask=None,
                src_key_padding_mask=None,
                attention_prior=None):
        """Run one encoder layer over ``src``.

        ``src_mask`` and ``src_key_padding_mask`` are accepted for signature
        compatibility but are NOT applied; attention sees every position.
        ``attention_prior`` is forwarded to the self-attention module.
        """
        attended, _ = self.self_attn(src, src, src,
                                     attention_prior=attention_prior)
        hidden = self.norm1(src + self.dropout1(attended))

        ff_out = self.linear2(
            self.dropout(self.activation(self.linear1(hidden))))
        return self.norm2(hidden + self.dropout2(ff_out))


class CustomTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` independent copies of an encoder layer.

    Mirrors ``nn.TransformerEncoder``. The given ``encoder_layer`` is a
    template that is deep-copied, so each layer owns its own parameters
    (all copies start from the template's values).

    Args:
        encoder_layer: template layer. Its ``forward`` must accept
            ``src_mask``, ``src_key_padding_mask`` and ``attention_prior``
            keyword arguments.
        num_layers: number of layers in the stack.
    """

    def __init__(self, encoder_layer, num_layers):
        super(CustomTransformerEncoder, self).__init__()
        # Deep-copy each entry: repeating the same instance would silently tie
        # every layer's weights together and make depth meaningless.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)])

    def forward(self,
                src,
                mask=None,
                src_key_padding_mask=None,
                attention_prior=None):
        """Apply each layer in order.

        ``mask`` is passed to every layer as ``src_mask``.
        ``src_key_padding_mask`` and ``attention_prior`` are passed to every
        layer unchanged. With zero layers, ``src`` is returned as-is.
        """
        output = src
        for mod in self.layers:
            output = mod(output,
                         src_mask=mask,
                         src_key_padding_mask=src_key_padding_mask,
                         attention_prior=attention_prior)
        return output


class ATA(nn.Module):
    """Transformer encoder/decoder over batch-first sequences (no latent bottleneck).

    NOTE(review): a second ``class ATA`` later in this module rebinds the
    name ``ATA``, so after import this definition cannot be reached by that
    name. Consider renaming or removing it.

    Args:
        input_dim: feature size of each time step (``d_model``). Must be
            divisible by ``decoder_nhead``.
        time_steps: exact sequence length accepted by ``forward``.
        decoder_nhead: attention heads in each decoder layer.
        num_encoder_layers: depth of the custom encoder stack.
        num_decoder_layers: depth of the ``nn.TransformerDecoder``.
        dim_feedforward: hidden width of the feed-forward sub-layers.
    """

    def __init__(self,
                 input_dim,
                 time_steps,
                 decoder_nhead=4,
                 num_encoder_layers=3,
                 num_decoder_layers=3,
                 dim_feedforward=256):
        super(ATA, self).__init__()
        self.input_dim = input_dim
        self.time_steps = time_steps
        self.decoder_nhead = decoder_nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward

        encoder_layer = CustomTransformerEncoderLayer(
            d_model=input_dim, dim_feedforward=dim_feedforward)
        self.transformer_encoder = CustomTransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers)

        # batch_first=True: inputs are (batch, time, dim), like the custom
        # encoder and PositionalEncoding. The PyTorch default (False) would
        # treat the batch axis as the sequence axis and mix samples together.
        decoder_layer = TransformerDecoderLayer(
            d_model=input_dim,
            nhead=self.decoder_nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True)
        self.transformer_decoder = TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers)

        self.positional_encoding = PositionalEncoding(input_dim,
                                                      max_len=time_steps)

    def forward(self, x, attention_prior=None):
        """Encode ``x`` and decode it against the encoder output.

        Args:
            x: tensor of shape (batch, time_steps, input_dim).
            attention_prior: optional prior forwarded to every encoder layer.

        Returns:
            Decoder output with the same shape as ``x``.
        """
        _, time_steps, input_dim = x.size()
        assert time_steps == self.time_steps and input_dim == self.input_dim, (
            f"expected (batch, {self.time_steps}, {self.input_dim}), "
            f"got {tuple(x.shape)}")

        x = self.positional_encoding(x)

        memory = self.transformer_encoder(x, attention_prior=attention_prior)

        return self.transformer_decoder(x, memory)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to ``x`` and apply dropout.

    Column ``2i`` of the table holds ``sin(pos / 10000**(2i / d_model))``;
    column ``2i + 1`` holds the matching cosine. Odd ``d_model`` is
    supported, in which case the last column is a sine.

    Args:
        d_model: size of the last dimension of ``x``.
        max_len: number of precomputed positions. ``x.size(1)`` must not
            exceed it.
        dropout: dropout probability applied after the addition. The default
            of 0.1 matches the previously hard-coded value.
    """

    def __init__(self, d_model, max_len=50, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) *
            -(torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, one fewer than div_term entries
        # when d_model is odd, so slice to keep the assignment shapes equal.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # Leading broadcast dim: the buffer has shape (1, max_len, d_model).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Positions are read along dim 1 of ``x``, and the table broadcasts over
        dim 0. In other words, ``x`` is expected to be (batch, seq, d_model).
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class ATA(nn.Module):
    """Transformer autoencoder with a flat latent bottleneck.

    Pipeline:

    1. Encode a (batch, time_steps, input_dim) sequence with the custom
       encoder.
    2. Flatten it and project it to ``latent_dim``.
    3. Expand it back to (batch, time_steps, input_dim) and re-add
       positional encodings.
    4. Decode with ``nn.TransformerDecoder``, using the encoder output as
       memory.

    Args:
        input_dim: feature size of each time step. Must be divisible by
            ``decoder_nhead``.
        time_steps: exact sequence length accepted by ``forward``.
        latent_dim: size of the bottleneck vector.
        decoder_nhead: attention heads per decoder layer (default 4).
        num_encoder_layers: depth of the custom encoder stack.
        num_decoder_layers: depth of the ``nn.TransformerDecoder``.
        dim_feedforward: hidden width of the feed-forward sub-layers.
    """

    def __init__(self,
                 input_dim,
                 time_steps,
                 latent_dim,
                 decoder_nhead=4,
                 num_encoder_layers=2,
                 num_decoder_layers=2,
                 dim_feedforward=128):
        super(ATA, self).__init__()
        self.input_dim = input_dim
        self.time_steps = time_steps
        self.latent_dim = latent_dim
        self.decoder_nhead = decoder_nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward

        encoder_layer = CustomTransformerEncoderLayer(
            d_model=input_dim, dim_feedforward=dim_feedforward)
        self.transformer_encoder = CustomTransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers)

        self.latent_fc = nn.Linear(input_dim * time_steps, latent_dim)

        self.decoder_input_fc = nn.Linear(latent_dim, input_dim * time_steps)

        # batch_first=True: tensors here are (batch, time, dim). The PyTorch
        # default (False) would attend across the batch axis and mix samples.
        decoder_layer = TransformerDecoderLayer(
            d_model=input_dim,
            nhead=decoder_nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True)
        self.transformer_decoder = TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers)

        self.positional_encoding = PositionalEncoding(input_dim,
                                                      max_len=time_steps)

    def forward(self, x, attention_prior=None):
        """Auto-encode ``x``.

        Args:
            x: tensor of shape (batch, time_steps, input_dim).
            attention_prior: optional prior forwarded to every encoder layer.

        Returns:
            ``(decoded, latent)``. ``decoded`` has the same shape as ``x``;
            ``latent`` has shape (batch, latent_dim).
        """
        batch_size, time_steps, input_dim = x.size()
        assert time_steps == self.time_steps and input_dim == self.input_dim, (
            f"expected (batch, {self.time_steps}, {self.input_dim}), "
            f"got {tuple(x.shape)}")

        x = self.positional_encoding(x)

        # Encoder -> flat latent vector.
        encoded = self.transformer_encoder(x, attention_prior=attention_prior)
        latent = self.latent_fc(encoded.reshape(batch_size, -1))

        # Latent -> sequence -> decoder, attending to the encoder output.
        decoder_input = self.decoder_input_fc(latent).view(
            batch_size, time_steps, input_dim)
        decoder_input = self.positional_encoding(decoder_input)

        decoded = self.transformer_decoder(decoder_input, encoded)

        return decoded, latent


# Smoke test for the ATA module: run DGFS to get an attention prior, then
# feed the same random batch through ATA and print the resulting shapes.
if __name__ == "__main__":
    from dgfs import DGFS

    input_dim = 76
    time_steps = 48
    latent_dim = 128
    batch_size = 32
    decoder_nhead = 4  # must divide input_dim

    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    x = torch.randn(batch_size, time_steps, input_dim).to(device)

    dgfs = DGFS(input_dim, time_steps, latent_dim).to(device)
    # decoder_nhead is passed explicitly; the ATA signature required it.
    ata = ATA(input_dim, time_steps, latent_dim, decoder_nhead).to(device)

    rpr, selected_features, attention_prior = dgfs(x)
    output, latent = ata(x, attention_prior)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Latent shape:", latent.shape)
    print("rpr shape:", rpr.shape)
