#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /data/zhangruochi/projects/AAAI/FairCare/utils/preprocessing.py
# Project: /mnt/data/zhangruochi/DeepOrchestration/utils
# Created Date: Saturday, January 22nd 2022, 5:30:49 pm
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Mon Sep 29 2025
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2022 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2022 Ruochi Zhang Ltd
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###
import os
import sys

# Prepend "<parent of this file's directory>/benchmarks" to the module search
# path so it takes priority over other entries.
# NOTE(review): presumably this is where the `mimic3models` and
# `mimic3benchmark` packages imported below live — confirm the repo layout.
sys.path.insert(
    0,
    os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                 "benchmarks"))

from mimic3models.in_hospital_mortality import utils as mortality_utils
from mimic3models.preprocessing import Discretizer, Normalizer
from mimic3models import common_utils
from mimic3benchmark.readers import InHospitalMortalityReader

from pathlib import Path
from typing import List
import numpy as np
import pickle as pkl


def preprocess_t_mimic(dataset_path: Path, cache_dir: Path,
                       normalizer_state: str) -> tuple:
    """Load the in-hospital-mortality splits, using a pickle cache.

    On the first call the splits are read through ``InHospitalMortalityReader``,
    discretized and normalized via ``mortality_utils.load_data``, and the result
    is pickled to ``<cache_dir>/mimic3.pkl``. Later calls with the same
    ``cache_dir`` return the pickled content without touching ``dataset_path``
    or ``normalizer_state``.

    Args:
        dataset_path: Directory containing ``train/`` and ``test/``
            sub-directories, each with a ``listfile.csv``.
        cache_dir: Directory holding the cache file; created if missing.
            Accepts a ``str`` or a ``Path``.
        normalizer_state: Path to the saved normalizer parameters, joined onto
            the directory of this source file (an absolute path is used as-is).

    Returns:
        The 9-tuple ``(X_train, y_train, train_names, X_valid, y_valid,
        valid_names, X_test, y_test, test_names)``.
    """
    # Accept plain strings too: the "/" operator below requires a Path.
    cache_dir = Path(cache_dir)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    cache_dir.mkdir(parents=True, exist_ok=True)

    cache_file = cache_dir / "mimic3.pkl"

    if cache_file.exists():
        # SECURITY NOTE: unpickling can execute arbitrary code; only point
        # cache_dir at locations you trust.
        with open(cache_file, "rb") as f:
            (X_train, y_train, train_names, X_valid, y_valid, valid_names,
             X_test, y_test, test_names) = pkl.load(f)
    else:

        def _make_reader(split: str):
            # One reader per split directory, with a 48.0 period length.
            split_dir = os.path.join(dataset_path, split)
            return InHospitalMortalityReader(
                dataset_dir=split_dir,
                listfile=os.path.join(split_dir, 'listfile.csv'),
                period_length=48.0)

        train_reader = _make_reader('train')
        # NOTE(review): the validation reader uses the *test* directory and
        # listfile, so the validation and test splits are identical. Kept
        # as-is because no separate validation listfile is visible from here;
        # confirm whether this is intentional.
        val_reader = _make_reader('test')
        test_reader = _make_reader('test')

        discretizer = Discretizer(timestep=1.0,
                                  store_masks=True,
                                  impute_strategy='previous',
                                  start_time='zero')

        # The second element returned by transform() is a comma-separated
        # header describing the discretized columns.
        discretizer_header = discretizer.transform(
            train_reader.read_example(0)["X"])[1].split(',')

        # Columns whose header does not contain "->" are the ones standardized.
        cont_channels = [
            i for (i, x) in enumerate(discretizer_header) if "->" not in x
        ]
        normalizer = Normalizer(fields=cont_channels)
        normalizer_state = os.path.join(os.path.dirname(__file__),
                                        normalizer_state)
        normalizer.load_params(normalizer_state)

        print('Reading data and extracting features ...')

        train_raw = mortality_utils.load_data(train_reader, discretizer,
                                              normalizer, False, True)
        valid_raw = mortality_utils.load_data(val_reader, discretizer,
                                              normalizer, False, True)
        test_raw = mortality_utils.load_data(test_reader, discretizer,
                                             normalizer, False, True)

        # Original author's note: (3222, 48, 76) -> [batch_size, T, features]
        X_train, y_train = train_raw["data"][0], np.array(train_raw["data"][1])
        X_valid, y_valid = valid_raw["data"][0], np.array(valid_raw["data"][1])
        X_test, y_test = test_raw["data"][0], np.array(test_raw["data"][1])

        train_names = train_raw["names"]
        valid_names = valid_raw["names"]
        test_names = test_raw["names"]

        # Write to a temporary file first and move it into place, so an
        # interrupted run never leaves a truncated cache behind.
        tmp_file = cache_dir / "mimic3.pkl.tmp"
        with open(tmp_file, "wb") as f:
            pkl.dump([
                X_train, y_train, train_names, X_valid, y_valid, valid_names,
                X_test, y_test, test_names
            ], f)
        os.replace(tmp_file, cache_file)

    print('  train data shape = {}'.format(X_train.shape))
    print('  validation data shape = {}'.format(X_valid.shape))
    print('  test data shape = {}'.format(X_test.shape))

    return (X_train, y_train, train_names, X_valid, y_valid, valid_names,
            X_test, y_test, test_names)


def preprocess_demo(demo_path: str):
    """Build per-episode demographic and diagnosis vectors from a directory.

    Every regular file in ``demo_path`` whose name contains an underscore is
    read. The name is split at the first ``_`` into an id and an episode part,
    and the last four characters of the episode part are dropped (presumably
    the ``.csv`` extension). Files whose first header field is not
    ``"Icustay"`` are skipped. Only the first data row of each file is used.

    Layout of the 12-element demographic vector, as written by this function:
      * field 1 of the row selects the slot that is set to 1,
      * field 2 of the row selects slot ``5 + value`` that is set to 1,
      * slots 9-11 hold fields 3, 4 and 5 as floats; blank fields fall back to
        60.0, 160 and 60 respectively
        (NOTE(review): presumably age / height / weight — confirm against the
        file header).
    Slots 9-11 are then z-scored across all returned episodes (std floored at
    1e-7). Rows with no data contribute all-zero vectors to those statistics.

    The diagnosis vector is fields 8 onward parsed as ``int32``; a row with no
    data yields ``np.zeros(128)`` (float) instead.

    Args:
        demo_path: Directory to scan. A trailing path separator is optional.

    Returns:
        ``(demographic_data, diagnosis_data, idx_list)`` — three parallel
        lists ordered by sorted file name; ``idx_list`` entries are
        ``"<id>_<episode>"``.
    """
    # Fallbacks for blank values in fields 3, 4 and 5 (in that order).
    numeric_defaults = (60.0, 160, 60)

    demographic_data = []
    diagnosis_data = []
    idx_list = []

    # Sorted so the output order does not depend on the filesystem.
    for cur_name in sorted(os.listdir(demo_path)):
        # os.path.join works with or without a trailing separator on demo_path.
        cur_file = os.path.join(demo_path, cur_name)
        # Skip entries that cannot be "<id>_<episode>.<ext>" files.
        if '_' not in cur_name or not os.path.isfile(cur_file):
            continue

        cur_id, cur_episode = cur_name.split('_', 1)
        cur_episode = cur_episode[:-4]

        with open(cur_file, "r") as tsfile:
            header = tsfile.readline().strip().split(',')
            if header[0] != "Icustay":
                continue
            cur_data = tsfile.readline().strip().split(',')

        cur_demo = np.zeros(12)
        if len(cur_data) == 1:
            # Missing/blank data row: keep all-zero placeholders.
            cur_diag = np.zeros(128)
        else:
            cur_demo[int(cur_data[1])] = 1
            cur_demo[5 + int(cur_data[2])] = 1
            for offset, fallback in enumerate(numeric_defaults):
                raw = cur_data[3 + offset]
                cur_demo[9 + offset] = fallback if raw == '' else float(raw)
            cur_diag = np.array(cur_data[8:], dtype=np.int32)

        demographic_data.append(cur_demo)
        diagnosis_data.append(cur_diag)
        idx_list.append(cur_id + '_' + cur_episode)

    # Standardize the three numeric slots in place; nothing to do when no
    # episode was loaded (avoids taking the mean of an empty array).
    if demographic_data:
        for each_idx in range(9, 12):
            cur_val = np.array([row[each_idx] for row in demographic_data])
            _mean = np.mean(cur_val)
            _std = np.std(cur_val)
            _std = _std if _std > 1e-7 else 1e-7
            for row in demographic_data:
                row[each_idx] = (row[each_idx] - _mean) / _std

    return demographic_data, diagnosis_data, idx_list
