import sys
import os

from ..utils.preprocessing import preprocess_demo
import torch
import numpy as np
from torch.utils.data import Dataset


class TabularDataset(Dataset):
    """Map-style dataset that pairs each feature sample with its label.

    Indexing returns a ``(features, label)`` tuple of ``torch.float32``
    tensors built from ``X_train[index]`` and ``y_train[index]``.

    NOTE(review): an earlier revision also loaded demographic and diagnosis
    features via ``preprocess_demo(demo_path)``; that path is disabled, so
    ``demo_path`` is currently accepted but ignored.
    """

    def __init__(self, X_train, y_train, train_names, demo_path=None):
        """Store the features, labels and sample names.

        Args:
            X_train: Indexable, sized collection of samples (for example a
                NumPy array or a nested list). A previous debug note recorded
                a shape of (17903, 48, 76) -- presumably
                (samples, time steps, features); confirm against the caller.
            y_train: Indexable collection of labels, aligned with ``X_train``
                by position.
            train_names: Sample identifiers. Stored as-is; not read by this
                class.
            demo_path: Unused. Kept (now optional) so existing callers that
                pass it keep working.
        """
        super().__init__()

        self.X_train = X_train
        self.y_train = y_train
        self.train_names = train_names

    def __getitem__(self, index):
        """Return ``(features, label)`` for ``index`` as float32 tensors."""
        features = torch.tensor(self.X_train[index], dtype=torch.float32)
        label = torch.tensor(self.y_train[index], dtype=torch.float32)
        return features, label

    def __len__(self) -> int:
        """Return the number of samples.

        Uses ``len()`` rather than ``.shape[0]`` so that plain sequences
        (which ``__getitem__`` already supports) work as well as arrays.
        """
        return len(self.X_train)
