# DeepSelective: Interpretable Prognosis Prediction via Feature Selection and Compression in EHR Data

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.2.0-red.svg)](https://pytorch.org/)

DeepSelective is an end-to-end deep learning framework for clinical prognosis prediction from Electronic Health Records (EHR). It combines feature selection with deep representation learning to achieve high predictive accuracy while remaining interpretable: predictions rest on an explicitly selected subset of input features, which supports clinical decision-making.

## 🚀 Key Features

- **Dual Compression Design**: Integrates sparsity compression (DGFS) and perceptual compression (ATA) to obtain a compact yet informative feature representation
- **Interpretable Feature Selection**: Uses Gumbel-Softmax gating with dynamic sparsity control to select the most informative features
- **End-to-End Training**: All modules are trained jointly rather than in separate stages
- **Clinical Interpretability**: Exposes which features were selected and how important each one is, so predictions can be traced back to clinical variables
- **Competitive Benchmark Performance**: Outperforms the compared baselines on the MIMIC-III mortality and decompensation tasks and on three of the four MIMIC-IV tasks

## 🏗️ Architecture

DeepSelective consists of three core modules:

### 1. Dynamic Gate Feature Selection (DGFS) Module
- **Purpose**: Performs sparsity compression by selecting the most informative subset of features
- **Key Components**:
  - Gating Network: Generates probabilistic feature selection vectors using Gumbel-Softmax
  - SparsityController: Dynamically adjusts the Gumbel-Softmax temperature with a PID controller to track a target sparsity level
- **Benefits**: Reduces dimensionality while preserving critical information
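A minimal NumPy sketch of a Gumbel-Softmax feature gate may help make the mechanism concrete. It assumes a hypothetical gating-network output of two logits per feature ("keep" and "drop") and returns a soft keep-probability that multiplies each feature; the actual DGFS layer may be parameterised differently. In PyTorch the sampling step is also available as `torch.nn.functional.gumbel_softmax`.

```python
import numpy as np


def gumbel_softmax_gate(logits, temperature, rng):
    """Relaxed keep/drop gate per feature (illustrative sketch).

    logits: array of shape (n_features, 2); column 0 = "keep", column 1 = "drop".
    Returns the soft "keep" probability of each feature, in [0, 1].
    """
    # Gumbel(0, 1) noise via the inverse CDF: -log(-log(U)), U ~ Uniform(0, 1)
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    y = (logits + gumbel) / temperature
    # Numerically stable softmax over the keep/drop axis
    y = y - y.max(axis=-1, keepdims=True)
    probs = np.exp(y) / np.exp(y).sum(axis=-1, keepdims=True)
    return probs[:, 0]


rng = np.random.default_rng(0)
n_features, time_steps = 76, 48                # input_dim / time_steps from config/train.yaml
logits = rng.normal(size=(n_features, 2))      # stand-in for the gating network's output
gate = gumbel_softmax_gate(logits, temperature=2.0, rng=rng)

x = rng.normal(size=(time_steps, n_features))  # one ICU stay: time steps x features
x_sparse = x * gate                            # the same gate is applied at every time step
print(gate.shape, x_sparse.shape)              # (76,) (48, 76)
```

Lower temperatures push the sampled gates toward hard 0/1 decisions, while higher temperatures keep them soft, which is why the temperature is a natural knob for a sparsity controller.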

### 2. Attentive Transformer Autoencoder (ATA) Module
- **Purpose**: Performs perceptual compression through deep representation learning
- **Key Components**:
  - Encoder: Maps sparse features to dense latent representations
  - Decoder: Reconstructs original features from latent space
  - Attention Prior: Incorporates the feature importance scores from the DGFS module as a prior on attention
- **Benefits**: Captures complex inter-feature patterns and relationships
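One common way to let a feature-importance vector steer attention is to add its logarithm to the attention logits, which rescales each key's attention weight by its prior before renormalising. The NumPy sketch below shows that idea; treating each feature as one token and using DGFS keep-probabilities as the prior are assumptions for illustration, not a description of the ATA code.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)    # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_with_prior(q, k, v, prior, eps=1e-8):
    """Scaled dot-product attention biased by a per-key prior (illustrative sketch).

    q: (n_q, d); k, v: (n_k, d); prior: (n_k,) non-negative importance weights.
    Adding log(prior) to the logits multiplies each key's attention weight by its
    prior (up to eps) before renormalisation.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)              # (n_q, n_k)
    scores = scores + np.log(prior + eps)      # same prior for every query row
    weights = softmax(scores, axis=-1)
    return weights @ v, weights


rng = np.random.default_rng(0)
n_features, d = 76, 16
tokens = rng.normal(size=(n_features, d))      # hypothetical per-feature embeddings
prior = rng.uniform(size=n_features)           # e.g. DGFS keep-probabilities
prior[0] = 0.0                                 # a feature the gate switched off
out, weights = attention_with_prior(tokens, tokens, tokens, prior)
print(out.shape, weights[:, 0].max() < 1e-5)   # (76, 16) True
```

A feature whose prior is near zero receives almost no attention from any query, so the autoencoder focuses on the features DGFS kept.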

### 3. Representation Matching Layer (RML) Module
- **Purpose**: Harmonizes the representations produced by the DGFS and ATA modules
- **Key Operations**:
  - Aligned Information (`r_add`): Captures shared information between representations
  - Complementary Information (`r_sub`): Captures information present in one representation but not the other
- **Benefits**: Encourages consistency between the two views while retaining their complementary information
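The operation names suggest that `r_add` and `r_sub` are the element-wise sum and difference of the two representations. The sketch below makes that assumption explicit, and also assumes both modules emit vectors of the same size (such as `latent_dim`); the actual RML may apply learned projections or additional terms.

```python
import numpy as np


def match_representations(h_dgfs, h_ata):
    """Fuse two same-shaped representations (illustrative sketch).

    Assumption: r_add is the element-wise sum (what the two views share) and
    r_sub the element-wise difference (what distinguishes them).
    """
    if h_dgfs.shape != h_ata.shape:
        raise ValueError("representations must share a shape, e.g. (batch, latent_dim)")
    r_add = h_dgfs + h_ata
    r_sub = h_dgfs - h_ata
    # Concatenate so a downstream classifier sees both signals
    return np.concatenate([r_add, r_sub], axis=-1)


h1 = np.array([[1.0, 2.0]])
h2 = np.array([[1.0, 0.0]])
print(match_representations(h1, h2))   # [[2. 2. 0. 2.]]
```

Where the two views agree, `r_sub` is zero and `r_add` doubles the signal; where they disagree, `r_sub` carries the difference forward instead of letting it cancel out.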

## 📊 Performance Results

### MIMIC-III Benchmark Results

| Task | Method | AUROC | AUPRC | min(Se,P+) |
|------|--------|-------|-------|------------|
| **Mortality Prediction** | DeepSelective | **0.9054** | **0.5627** | **0.5218** |
| | ConCare | 0.8702 | 0.5317 | 0.5082 |
| | GRU_α | 0.8628 | 0.4989 | 0.5026 |
| **Decompensation Prediction** | DeepSelective | **0.9143** | **0.3203** | **0.3573** |
| | AdaCare | 0.9004 | 0.3037 | 0.3429 |
| | GRU_α | 0.8983 | 0.2784 | 0.3260 |

### MIMIC-IV Benchmark Results

| Task | DeepSelective | Best Baseline |
|------|---------------|---------------|
| Mortality | 0.8325 | **0.8388** (Llemr) |
| Readmission | **0.7343** | 0.7251 (Llemr) |
| Length-of-Stay | **0.7149** | 0.7132 (Llemr) |
| Diagnosis | **0.8257** | 0.8128 (REMed) |

## 🛠️ Installation

### Prerequisites
- Python 3.8+
- CUDA 11.8+ (for GPU acceleration)
- Conda package manager

### Environment Setup

1. **Create conda environment**:
```bash
conda env create -f environment.yaml
conda activate deepor
```

2. **Configure MLflow (optional)**, then reactivate the environment (`conda activate deepor`) so the variables take effect:
```bash
conda env config vars set MLFLOW_TRACKING_URI="http://your-mlflow-server:5005/"
conda env config vars set MLFLOW_EXPERIMENT_NAME="deepor"
conda env config vars set REGISTERED_MODEL_NAME="deepor"
```
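MLflow reads `MLFLOW_TRACKING_URI` and `MLFLOW_EXPERIMENT_NAME` from the environment on its own, while `REGISTERED_MODEL_NAME` appears to be specific to this project. Because `conda env config vars set` only takes effect after reactivation, a quick check like the following (a hypothetical helper, not part of the repository) can confirm the variables are visible to Python:

```python
import os

REQUIRED = ("MLFLOW_TRACKING_URI", "MLFLOW_EXPERIMENT_NAME", "REGISTERED_MODEL_NAME")


def missing_mlflow_vars(env=os.environ):
    """Return the names of the MLflow-related variables that are unset or empty."""
    return [name for name in REQUIRED if not env.get(name)]


if __name__ == "__main__":
    missing = missing_mlflow_vars()
    if missing:
        print("Missing:", ", ".join(missing), "(did you reactivate the conda env?)")
    else:
        print("MLflow variables are set.")
```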

## 📁 Project Structure

```
DeepSelective/
├── config/                 # Configuration files
│   ├── train.yaml         # Main training configuration
│   ├── nni_config.yaml    # Neural Network Intelligence config
│   └── search_space.json  # Hyperparameter search space
├── dataloader/            # Data loading utilities
│   └── mimic_dataloader.py
├── dataset/               # Dataset classes
│   └── mimic_dataset.py
├── models/                # Model implementations
│   ├── deepor.py         # Main DeepSelective model
│   ├── dgfs.py          # DGFS module
│   ├── ata.py           # ATA module
│   ├── rml.py           # RML module
│   └── baseline.py      # Baseline models
├── utils/                # Utility functions
│   ├── trainer.py       # Training logic
│   ├── evaluator.py     # Evaluation metrics
│   ├── loss.py          # Loss functions
│   └── metrics.py       # Performance metrics
├── logger/               # Logging utilities
│   ├── mlflow_logger.py
│   └── std_logger.py
├── benchmarks/           # Benchmark datasets
├── docs/                # Documentation
│   └── manuscript.tex   # Research paper
├── main.py             # Main training script
└── environment.yaml    # Conda environment file
```

## 🚀 Quick Start

### 1. Data Preparation

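MIMIC-III preprocessing uses the scripts from the [MIMIC-III Benchmark](https://github.com/YerevaNN/mimic3-benchmarks) repository; run the following commands from the root of a clone of that repository: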

```bash
# Create conda environment for preprocessing
conda create -n mimic3 python=3.6
conda activate mimic3

# Install requirements
conda install numpy pandas=0.23.4 pyyaml tqdm

# Extract and process MIMIC-III data (../mimic3-rawdata holds the raw MIMIC-III CSV files)
python -m mimic3benchmark.scripts.extract_subjects ../mimic3-rawdata ./mimic3-extract
python -m mimic3benchmark.scripts.validate_events ./mimic3-extract
python -m mimic3benchmark.scripts.extract_episodes_from_subjects ./mimic3-extract
python -m mimic3benchmark.scripts.split_train_and_test ./mimic3-extract

# Create prediction tasks
python -m mimic3benchmark.scripts.create_in_hospital_mortality ./mimic3-extract ./in-hospital-mortality/
python -m mimic3benchmark.scripts.create_decompensation ./mimic3-extract ./decompensation/
```
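After the task scripts finish, each task directory should contain `train/` and `test/` folders with a `listfile.csv` index. The helper below is a sketch for checking label balance; it assumes the mimic3-benchmarks layout, in which these listfiles have a header row and a binary `y_true` column, so verify the column name against your own output.

```python
import csv
from pathlib import Path


def label_balance(lines):
    """Return (rows, positives) for CSV text with a header and a 0/1 `y_true` column."""
    total = positives = 0
    for row in csv.DictReader(lines):
        total += 1
        positives += int(row["y_true"])
    return total, positives


if __name__ == "__main__":
    for split in ("train", "test"):
        path = Path("in-hospital-mortality") / split / "listfile.csv"
        if not path.exists():
            print(f"{path} not found")
            continue
        with path.open(newline="") as f:
            n, pos = label_balance(f)
        print(f"{split}: {n} rows, {pos} positive ({pos / max(n, 1):.1%})")
```

Point the path at `decompensation/` to inspect the second task; heavy class imbalance there is one reason AUPRC and min(Se, P+) are reported alongside AUROC.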

### 2. Training

```bash
# Basic training
python main.py

# Training with custom configuration
python main.py model.deepor.temperature=1.5 model.deepor.latent_dim=256

# Enable MLFlow logging
python main.py logger.mlflow=true

# Debug mode (2 epochs)
python main.py other.debug=true
```
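The `key.subkey=value` arguments look like Hydra/OmegaConf command-line overrides; we assume `main.py` uses such a config system (check its entry point to confirm). Conceptually, each override walks the nested configuration and replaces a single leaf. The plain-Python sketch below illustrates only that idea; it is not Hydra's parser, which also handles typed values, lists, and special prefixes.

```python
import ast
import copy


def apply_overrides(config, overrides):
    """Return a copy of a nested dict with dotted `a.b.c=value` overrides applied."""
    cfg = copy.deepcopy(config)
    for item in overrides:
        dotted, raw = item.split("=", 1)
        *parents, leaf = dotted.split(".")
        node = cfg
        for key in parents:
            node = node[key]                    # KeyError on a typo, like a strict config
        if raw.lower() in ("true", "false"):
            value = raw.lower() == "true"
        else:
            try:
                value = ast.literal_eval(raw)   # ints, floats, quoted strings
            except (ValueError, SyntaxError):
                value = raw                     # fall back to a bare string
        node[leaf] = value
    return cfg


base = {"model": {"deepor": {"temperature": 2.0, "latent_dim": 128}}, "other": {"debug": False}}
cfg = apply_overrides(base, ["model.deepor.temperature=1.5", "model.deepor.latent_dim=256"])
print(cfg["model"]["deepor"])   # {'temperature': 1.5, 'latent_dim': 256}
```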

### 3. Configuration

Key configuration parameters in `config/train.yaml`:

```yaml
model:
  deepor:
    input_dim: 76              # Input feature dimension
    time_steps: 48             # Time sequence length
    latent_dim: 128            # Latent representation dimension
    temperature: 2.0           # Gumbel-Softmax temperature
    setpoint: 0.1              # Target sparsity level for the PID controller
    Kp: 0.001                  # Proportional gain
    Ki: 0.001                  # Integral gain
    Kd: 0.001                  # Derivative gain

train:
  batch_size: 128
  learning_rate: 1e-4
  num_epoch: 50
  default_metric: "auroc"
```
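The `setpoint`, `Kp`, `Ki`, and `Kd` entries parameterise the SparsityController. A discrete PID loop of the following form is one plausible reading of these settings; which sparsity statistic is measured, the direction in which the temperature moves, and the clamp bounds are all assumptions here, so compare with `models/dgfs.py` before relying on it.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class SparsityPID:
    """Discrete PID controller for the Gumbel-Softmax temperature (illustrative sketch).

    Assumptions: `measured` is the fraction of features currently kept (mean gate
    value), and a positive control signal raises the temperature.
    """
    setpoint: float = 0.1
    Kp: float = 0.001
    Ki: float = 0.001
    Kd: float = 0.001
    temperature: float = 2.0
    t_min: float = 0.1        # clamp bounds are hypothetical
    t_max: float = 10.0
    _integral: float = field(default=0.0, init=False, repr=False)
    _prev_error: Optional[float] = field(default=None, init=False, repr=False)

    def step(self, measured: float) -> float:
        error = self.setpoint - measured
        self._integral += error
        derivative = 0.0 if self._prev_error is None else error - self._prev_error
        self._prev_error = error
        control = self.Kp * error + self.Ki * self._integral + self.Kd * derivative
        self.temperature = min(max(self.temperature + control, self.t_min), self.t_max)
        return self.temperature


pid = SparsityPID()
t = pid.step(measured=0.3)   # 30% of features kept vs. a 10% target
print(round(t, 6))           # 1.9996
```

With gains of 0.001 the temperature drifts slowly, so the controller acts as a gentle, persistent correction rather than an abrupt schedule change.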

## 🔬 Interpretability Features

### Feature Importance Analysis
DeepSelective provides multiple interpretability mechanisms:

1. **Mutual Information Analysis**: Quantifies how much information about the input features is retained after compression
2. **Statistical Significance Testing**: t-tests assess the statistical significance of the selected features
3. **Clinical Feature Mapping**: Relates the selected features to their clinical meaning
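To make the mutual-information idea concrete, here is a plug-in estimator between a continuous variable and a discrete one, using quantile binning. It is a generic sketch, not the estimator used in the paper: the analysis above compares inputs with compressed features, whereas this example pairs a feature with a binary outcome. scikit-learn's `mutual_info_classif` offers a nearest-neighbour alternative.

```python
import numpy as np


def mutual_information(x, y, bins=10):
    """Plug-in estimate of I(X; Y) in nats for continuous x and discrete y.

    x is discretised into equal-frequency (quantile) bins before counting.
    """
    edges = np.quantile(x, np.linspace(0, 1, bins + 1)[1:-1])   # bins - 1 inner edges
    x_idx = np.digitize(x, edges)                               # bin index in 0..bins-1
    _, y_idx = np.unique(y, return_inverse=True)
    joint = np.zeros((bins, y_idx.max() + 1))
    np.add.at(joint, (x_idx, y_idx), 1)                         # joint histogram counts
    joint /= joint.sum()
    px = joint.sum(axis=1, keepdims=True)
    py = joint.sum(axis=0, keepdims=True)
    nz = joint > 0                                              # 0 * log 0 is taken as 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / (px @ py)[nz])))


rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=5000)                      # binary outcome, e.g. mortality
informative = y + rng.normal(scale=0.5, size=5000)     # distribution shifts with the label
noise = rng.normal(size=5000)                          # unrelated to the label
print(mutual_information(informative, y) > mutual_information(noise, y))  # True
```

Plug-in estimates are biased upward by roughly (bins - 1)(classes - 1) / (2N), so the noise feature scores slightly above zero rather than exactly zero.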

### Example: Selected Features for Mortality Prediction
- Respiratory rate: Patient breathing status
- GCS Verbal Response: Cognitive function assessment
- GCS Eye Opening: Consciousness level
- Glucose levels: Metabolic status indicator

## 📈 Evaluation Metrics

The framework reports the following evaluation metrics:

- **AUROC**: Area Under the Receiver Operating Characteristic curve
- **AUPRC**: Area Under the Precision-Recall curve
- **F1-Score**: Harmonic mean of precision and recall
- **min(Se, P+)**: Maximum, over decision thresholds, of the smaller of sensitivity (Se) and precision/positive predictive value (P+)
- **Mutual Information**: Feature informativeness measure
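Because min(Se, P+) is less familiar than AUROC or AUPRC, here is a direct sketch that sweeps every distinct score as a decision threshold, following the MIMIC-III benchmark convention (AUROC and AUPRC themselves are available as `sklearn.metrics.roc_auc_score` and `sklearn.metrics.average_precision_score`). The loop is quadratic in the number of distinct scores: fine for illustration, slower than a sort-based implementation.

```python
import numpy as np


def min_se_ppv(y_true, y_score):
    """Max over thresholds of min(sensitivity, precision) for binary labels."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    n_pos = max(int(y_true.sum()), 1)
    best = 0.0
    for t in np.unique(y_score):            # every distinct score is a candidate threshold
        pred = y_score >= t
        tp = int(np.sum(pred & y_true))
        se = tp / n_pos                     # sensitivity (recall)
        ppv = tp / max(int(pred.sum()), 1)  # precision (positive predictive value)
        best = max(best, min(se, ppv))
    return best


print(round(min_se_ppv([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]), 4))  # 0.6667
```

In the example, the best threshold is 0.35: it catches both positives (Se = 1) at the cost of one false positive (P+ = 2/3).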

## 🧪 Ablation Studies

Ablation studies measure the contribution of each component by removing it and recording the change in AUROC:

| Variant | Description | Performance Drop |
|---------|-------------|-----------------|
| DeepSelective_na | Remove ATA module | -5.3% AUROC |
| DeepSelective_nd | Remove DGFS module | -4.2% AUROC |
| DeepSelective_nc | Remove dynamic sparsity control | -2.1% AUROC |
| DeepSelective_nr | Remove RML module | -2.0% AUROC |

## 📚 Citation

If you use DeepSelective in your research, please cite our paper:

```bibtex
@article{zhang2025deepselective,
  title={DeepSelective: Interpretable Prognosis Prediction via Feature Selection and Compression in EHR Data},
  author={Zhang, Ruochi and Yang, Qian and Wang, Xiaoyang and Wang, Tian and Zhou, Qiong and Deng, Ziqi and Li, Kewei and Wang, Yueying and Fan, Yusi and Zhang, Jiale and Huang, Lan and Liu, Chang and Zhou, Fengfeng},
  journal={Pattern Recognition},
  year={2025}
}
```

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.


## 📞 Contact

- **Ruochi Zhang**: zrc720@gmail.com
- **Fengfeng Zhou**: FengfengZhou@gmail.com
- **Project Website**: [http://www.healthinformaticslab.org/supp/resources.php](http://www.healthinformaticslab.org/supp/resources.php)

## 🔗 Related Resources

- [MIMIC-III Dataset](https://mimic.physionet.org/)
- [MIMIC-IV Dataset](https://mimic.mit.edu/)
- [MIMIC-III Benchmark](https://github.com/YerevaNN/mimic3-benchmarks)
- [PyTorch Documentation](https://pytorch.org/docs/)

---

**Note**: This framework is designed for research purposes. For clinical applications, please ensure proper validation and regulatory compliance.