#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/DeepOrchestration/models/rml.py
# Project: /home/richard/projects/DeepOrchestration/models
# Created Date: Thursday, May 30th 2024, 10:12:57 am
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Thu May 30 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2024 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2024 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

import torch
import torch.nn as nn
import torch.nn.functional as F


class RML(nn.Module):
    """Fuse the DGFS and ATA latent representations into one tensor.

    Two views of the input pair are computed and concatenated along the
    feature (last) dimension:

    * ``supportive``: a sigmoid gate computed from ``[z_dgfs, z_ata]`` scales
      the element-wise sum ``z_dgfs + z_ata``.
    * ``complementary``: ``ReLU(linear2(z_ata - z_dgfs))``, a projection of
      the difference between the two inputs.

    Args:
        latent_dim: Size of the last dimension of both inputs.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Maps the concatenated pair (2 * latent_dim) to one gate value per
        # latent feature.
        self.linear1 = nn.Linear(latent_dim * 2, latent_dim)
        # Projects the difference z_ata - z_dgfs.
        self.linear2 = nn.Linear(latent_dim, latent_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, z_dgfs, z_ata):
        """Combine two latent tensors of identical shape.

        Args:
            z_dgfs: Tensor of shape ``(..., latent_dim)``.
            z_ata: Tensor with exactly the same shape as ``z_dgfs``.

        Returns:
            Tensor of shape ``(..., 2 * latent_dim)``: the supportive part
            followed by the complementary part along the last dimension.

        Raises:
            ValueError: If ``z_dgfs`` and ``z_ata`` have different shapes.
        """
        # Explicit check instead of ``assert`` so it is not stripped by
        # ``python -O``.
        if z_dgfs.shape != z_ata.shape:
            raise ValueError(
                "Shapes of DGFS and ATA outputs must match. "
                f"Got {tuple(z_dgfs.shape)} and {tuple(z_ata.shape)}."
            )

        # Concatenate on the last (feature) axis, which is the axis nn.Linear
        # operates on. For 2-D (batch, latent_dim) inputs this is the same as
        # dim=1; it also supports extra leading dims or unbatched vectors.
        concat_repr = torch.cat((z_dgfs, z_ata), dim=-1)
        gate = self.sigmoid(self.linear1(concat_repr))
        supportive = gate * (z_dgfs + z_ata)

        complementary = F.relu(self.linear2(z_ata - z_dgfs))

        return torch.cat((supportive, complementary), dim=-1)


if __name__ == "__main__":
    # Smoke test: run DGFS and ATA on random input and fuse their latents.
    # The sibling modules are imported inside the guard so importing this
    # file as a library does not require them.
    from dgfs import DGFS
    from ata import ATA

    input_dim = 76
    time_steps = 48
    latent_dim = 128
    batch_size = 32

    dgfs = DGFS(input_dim, time_steps, latent_dim)
    ata = ATA(input_dim, time_steps, latent_dim)
    x = torch.randn(batch_size, time_steps, input_dim)
    # Only the latent representations (and DGFS's attention prior, which ATA
    # consumes) are used below; the other outputs are ignored.
    latent_dgfs, _selected_features, attention_prior = dgfs(x)
    _output, latent_ata = ata(x, attention_prior)

    print(latent_dgfs.shape)
    print(latent_ata.shape)

    rml = RML(latent_dim)

    combined = rml(latent_dgfs, latent_ata)

    # Last dimension should be 2 * latent_dim (supportive + complementary).
    print(combined.shape)
