morpho/psychnet at main · Jah-yee/morpho · GitHub
Skip to content

Latest commit

 

History

History

README.md

PsychNet

Multi-branch deep learning for psychiatric epidemiology — 基于广西精神分裂症住院数据 (32,459 records, 2014–2023) 的多分支深度学习架构。


Architecture

  Patient Record (17 features)
         │
    ┌────┴────┐
    ▼         ▼
DeepInsight   Feature
 Portrait    Tokenizer
 (32×32)    (17 tokens)
    │         │
    ▼         ▼
ViT-Clinical  FT-Transformer        ~15.9M params
 (384-dim)    (256-dim)              RTX 4070 Ti Super ✓
    │         │
    └────┬────┘
         ▼
  Cross-Modal Fusion
  (Bidirectional Attention + Gating)
         │
         ▼
    Gated Output (512-dim)
         │
    ┌────┼────┬────┬────┐
    ▼    ▼    ▼    ▼    ▼
  再入院 费用  住院  灾难性 嵌入向量
  预测  预测  天数  支出   (表型)

Quick Start

# 1. Setup
python setup.py

# 2. Quick test (~1 min)
python scripts/train.py --quick

# 3. Full training (~1-2 hours on RTX 4070 Ti Super)
python scripts/train.py --config configs/default.yaml

File Structure

psychnet/
├── data/                          数据管线 (1,308 行)
│   ├── preprocessing.py           CSV → Tensor 完整管线
│   ├── deepinsight.py             DeepInsight 临床肖像生成
│   └── dataset.py                 PyTorch Dataset + DataLoader
│
├── models/                        模型架构 (2,375 行)
│   ├── vision_branch.py           ViT-Clinical (6L/6H/384d)
│   ├── tabular_branch.py          FT-Transformer (4L/8H/256d)
│   ├── fusion.py                  双向交叉注意力 + 门控融合
│   ├── heads.py                   5个多任务预测头
│   └── psychnet.py                主模型组装
│
├── training/                      训练引擎
│   ├── losses.py                  Focal Loss + 不确定性加权
│   └── trainer.py                 混合精度 + 早停 + WandB
│
├── evaluation/                    评估与可解释性
│   ├── metrics.py                 AUC/F1/R²/Bootstrap CI
│   ├── interpretability.py        Grad-CAM + 注意力分析
│   └── visualization.py           10种出版级图表
│
├── scripts/
│   ├── train.py                   ★ 主入口:9阶段训练管线
│   └── download_external_data.py  外部数据下载
│
├── configs/default.yaml           超参数配置
├── external_data/                 GBD / NBS / WHO 参考数据
├── requirements.txt               Python 依赖
├── setup.py                       ★ 一键环境检测与验证
│
├── ARCHITECTURE.md                技术架构文档 (含论文 Figure 素材)
├── GUIDE.md                       完整操作指南 (数据→训练→评估→可视化)
├── CURSOR_SETUP.md                Cursor 专用部署指南
└── RESEARCH_PROPOSAL.md           研究规划文档

Total: 6,876 lines across 19 Python files


Documents

文档 用途 读者
ARCHITECTURE.md 模型架构详解 + 论文 Figure 1 素材 你 / 论文审稿人
GUIDE.md 数据管线→训练→评估→可视化全流程 你 / Cursor
CURSOR_SETUP.md 环境配置 + 跨平台部署 Cursor / 合作者
RESEARCH_PROPOSAL.md 研究规划 + 论文拆分策略 + 时间线 你 / PI

Training Outputs

训练完成后自动产出以下文件:

类型 文件 用途
过程记录 training_history.json 每 epoch loss/metrics,绘制学习曲线
过程记录 config.json 可复现的配置快照
模型权重 checkpoints/best_model.pt 验证集最佳模型(论文报告指标用这个)
模型权重 checkpoints/final_model.pt 最终 epoch 权重
评估指标 test_results.json AUC-ROC, AUC-PR, F1, R², RMSE, Bootstrap CI
评估指标 summary.json 最佳 epoch, 最终指标, 运行时间
图表 figures/*.png 10张300 DPI出版级图表(Nature/Lancet格式)
嵌入向量 embeddings/patient_embeddings.npy N×512 嵌入,用于聚类/表型发现
嵌入向量 embeddings/cluster_labels.npy HDBSCAN 聚类标签
可解释性 interpretability/feature_importance.json 17个特征的重要度排序
可解释性 interpretability/gate_weight_stats.json Vision vs Tabular 融合权重

Model Specs

组件 参数量 说明
Vision Branch (ViT-Clinical) 11,895,424 6层 Transformer, 384维, 处理32×32临床肖像
Tabular Branch (FT-Transformer) 4,270,336 4层 Transformer, 256维, 处理17维患者特征
Cross-Modal Fusion 1,480,448 双向交叉注意力 + 门控融合
Multi-Task Heads 526,340 5个预测头 (再入院/费用/天数/灾难性/嵌入)
Total 15,871,620 训练峰值 ~2GB 显存

Requirements

  • Python ≥ 3.8
  • PyTorch ≥ 2.0 (CUDA recommended)
  • NVIDIA GPU with ≥ 8GB VRAM (RTX 3060+ recommended)
  • 完整依赖见 requirements.txt