Multi-branch deep learning for psychiatric epidemiology — 基于广西精神分裂症住院数据 (32,459 records, 2014–2023) 的多分支深度学习架构。
Patient Record (17 features)
│
┌────┴────┐
▼ ▼
DeepInsight Feature
Portrait Tokenizer
(32×32) (17 tokens)
│ │
▼ ▼
ViT-Clinical FT-Transformer ~15.9M params
(384-dim) (256-dim) RTX 4070 Ti Super ✓
│ │
└────┬────┘
▼
Cross-Modal Fusion
(Bidirectional Attention + Gating)
│
▼
Gated Output (512-dim)
│
┌────┼────┬────┬────┐
▼ ▼ ▼ ▼ ▼
再入院 费用 住院 灾难性 嵌入向量
预测 预测 天数 支出 (表型)
# 1. Setup
python setup.py
# 2. Quick test (~1 min)
python scripts/train.py --quick
# 3. Full training (~1-2 hours on RTX 4070 Ti Super)
python scripts/train.py --config configs/default.yamlpsychnet/
├── data/ 数据管线 (1,308 行)
│ ├── preprocessing.py CSV → Tensor 完整管线
│ ├── deepinsight.py DeepInsight 临床肖像生成
│ └── dataset.py PyTorch Dataset + DataLoader
│
├── models/ 模型架构 (2,375 行)
│ ├── vision_branch.py ViT-Clinical (6L/6H/384d)
│ ├── tabular_branch.py FT-Transformer (4L/8H/256d)
│ ├── fusion.py 双向交叉注意力 + 门控融合
│ ├── heads.py 5个多任务预测头
│ └── psychnet.py 主模型组装
│
├── training/ 训练引擎
│ ├── losses.py Focal Loss + 不确定性加权
│ └── trainer.py 混合精度 + 早停 + WandB
│
├── evaluation/ 评估与可解释性
│ ├── metrics.py AUC/F1/R²/Bootstrap CI
│ ├── interpretability.py Grad-CAM + 注意力分析
│ └── visualization.py 10种出版级图表
│
├── scripts/
│ ├── train.py ★ 主入口:9阶段训练管线
│ └── download_external_data.py 外部数据下载
│
├── configs/default.yaml 超参数配置
├── external_data/ GBD / NBS / WHO 参考数据
├── requirements.txt Python 依赖
├── setup.py ★ 一键环境检测与验证
│
├── ARCHITECTURE.md 技术架构文档 (含论文 Figure 素材)
├── GUIDE.md 完整操作指南 (数据→训练→评估→可视化)
├── CURSOR_SETUP.md Cursor 专用部署指南
└── RESEARCH_PROPOSAL.md 研究规划文档
Total: 6,876 lines across 19 Python files
| 文档 | 用途 | 读者 |
|---|---|---|
| ARCHITECTURE.md | 模型架构详解 + 论文 Figure 1 素材 | 你 / 论文审稿人 |
| GUIDE.md | 数据管线→训练→评估→可视化全流程 | 你 / Cursor |
| CURSOR_SETUP.md | 环境配置 + 跨平台部署 | Cursor / 合作者 |
| RESEARCH_PROPOSAL.md | 研究规划 + 论文拆分策略 + 时间线 | 你 / PI |
训练完成后自动产出以下文件:
| 类型 | 文件 | 用途 |
|---|---|---|
| 过程记录 | training_history.json |
每 epoch loss/metrics,绘制学习曲线 |
| 过程记录 | config.json |
可复现的配置快照 |
| 模型权重 | checkpoints/best_model.pt |
验证集最佳模型(论文报告指标用这个) |
| 模型权重 | checkpoints/final_model.pt |
最终 epoch 权重 |
| 评估指标 | test_results.json |
AUC-ROC, AUC-PR, F1, R², RMSE, Bootstrap CI |
| 评估指标 | summary.json |
最佳 epoch, 最终指标, 运行时间 |
| 图表 | figures/*.png |
10张300 DPI出版级图表(Nature/Lancet格式) |
| 嵌入向量 | embeddings/patient_embeddings.npy |
N×512 嵌入,用于聚类/表型发现 |
| 嵌入向量 | embeddings/cluster_labels.npy |
HDBSCAN 聚类标签 |
| 可解释性 | interpretability/feature_importance.json |
17个特征的重要度排序 |
| 可解释性 | interpretability/gate_weight_stats.json |
Vision vs Tabular 融合权重 |
- Python ≥ 3.8
- PyTorch ≥ 2.0 (CUDA recommended)
- NVIDIA GPU with ≥ 8GB VRAM (RTX 3060+ recommended)
- 完整依赖见 requirements.txt
