#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/DeepOrchestration/models/baseline.py
# Project: /home/richard/projects/DeepOrchestration/models
# Created Date: Wednesday, May 29th 2024, 6:33:56 pm
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Thu May 30 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2024 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2024 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import softmax, cosine_similarity, mse_loss, relu, sigmoid


# Dynamic Gating Feature Selection Module (DGFS)
class DGFS(nn.Module):
    """Feature projection stage of the DeepOr model.

    The current implementation is a single learnable affine map
    (``nn.Linear``) from ``input_dim`` to ``hidden_dim`` applied to the last
    axis of the input; no gating is performed here.

    Args:
        input_dim: Size of the last axis of the incoming tensor.
        hidden_dim: Size of the last axis of the returned tensor.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=hidden_dim)

    def forward(self, x):
        """Project ``x`` from ``input_dim`` to ``hidden_dim`` on its last axis."""
        return self.fc(x)


# Attention-based Transformer Autoencoder (ATA)
class ATA(nn.Module):
    """Transformer-encoder stack used as the sequence encoder of DeepOr.

    Only the encoder half is implemented; the output has the same shape as
    the input.

    Args:
        hidden_dim: Model width (``d_model``) expected on the last axis.
        num_layers: Number of stacked encoder layers.
        nhead: Number of attention heads; must divide ``hidden_dim``.
        batch_first: When ``True`` (default) inputs are laid out as
            ``(batch, seq, hidden_dim)``, matching how this file builds its
            tensors. Pass ``False`` for PyTorch's ``(seq, batch, hidden_dim)``
            layout.
    """

    def __init__(self, hidden_dim, num_layers, nhead, batch_first=True):
        super().__init__()

        # PyTorch defaults to batch_first=False. Without overriding it, a
        # (batch, seq, dim) input makes self-attention run across the batch
        # axis, mixing unrelated samples instead of time steps.
        # The template layer stays an attribute so existing parameter names
        # / state_dict keys are unchanged; nn.TransformerEncoder works on its
        # own copies of it.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            batch_first=batch_first,
        )
        self.encoder = nn.TransformerEncoder(self.encoder_layer,
                                             num_layers=num_layers)

    def forward(self, x):
        """Encode ``x`` and return a tensor of the same shape."""
        return self.encoder(x)


# DeepOr Framework
class DeepOrBase(nn.Module):
    """Baseline DeepOr model: projection -> encoder -> max-pool -> head.

    The prediction head ends in a sigmoid, so the model emits one value in
    ``(0, 1)`` per pooled row.

    Args:
        input_dim: Size of the last axis of the input features.
        hidden_dim: Width used by the projection, encoder and head.
        num_layers: Number of encoder layers passed to ``ATA``.
        nhead: Number of attention heads passed to ``ATA``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, nhead):
        super().__init__()

        self.dgfs = DGFS(input_dim, hidden_dim)
        self.ata = ATA(hidden_dim, num_layers, nhead)

        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.prediction_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the sigmoid score for each pooled row of ``x``.

        ``x`` is presumably ``(batch, seq, input_dim)`` (that is how the
        example in this file builds it) -- TODO confirm with callers.
        """
        projected = self.dgfs(x)
        encoded = self.ata(projected)
        # Max-pool over axis 1 (this is a pooling step, not a residual
        # connection); the argmax indices are not needed.
        pooled = encoded.max(dim=1).values
        return self.prediction_head(pooled)


# Example usage
if __name__ == "__main__":
    # Smoke test: one forward/backward/optimizer step on random data.
    input_dim = 76  # Feature dimension
    hidden_dim = 128
    num_layers = 2
    nhead = 4

    model = DeepOrBase(input_dim, hidden_dim, num_layers, nhead)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # The model ends in a Sigmoid and emits a single probability per sample,
    # so binary cross-entropy is the matching loss. CrossEntropyLoss would
    # treat the (batch, 1) output as one-class logits and raise on target 1.
    criterion = nn.BCELoss()

    # Dummy input for demonstration
    batch_size = 32
    seq_length = 48
    x = torch.rand(batch_size, seq_length,
                   input_dim)  # Assuming [batch_size, 48, 76]

    output = model(x)

    # Binary labels, one per sample; BCELoss requires float targets with the
    # same shape as the predictions.
    target = torch.randint(0, 2, (batch_size, )).float()

    print("Target:", target.shape)
    print("Output:", output.shape)

    # Drop the trailing singleton axis so predictions are (batch_size,).
    loss = criterion(output.squeeze(-1), target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print("Output:", output)
    print("Loss:", loss.item())
