
DurraniBit/ProtoAlignNet


ProtoAlignNet: Prototype-Aligned Patch Learning for Robust Motor Imagery EEG Classification

This repository provides the implementation of ProtoAlignNet, a prototype-guided deep learning framework for robust motor imagery electroencephalography (MI-EEG) classification.

ProtoAlignNet learns discriminative EEG patch representations by combining local-to-global feature extraction, class-wise prototype learning, and local patch-to-prototype alignment. Instead of relying only on a conventional fully connected classifier, the model compares local EEG patch tokens with learned class prototypes and regularizes their alignment through a prototype-consistency objective.

Main Components

1. Local-to-Global EEG Patch Encoder

The encoder first extracts local spatiotemporal EEG patterns using compact convolutional layers and then models global dependencies between patch tokens using Transformer encoder blocks. This design helps capture both short-range EEG dynamics and long-range relationships across temporal patches.
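The local-to-global idea can be illustrated with a minimal, dependency-free sketch. It is not the repository's encoder. Mean pooling over non-overlapping temporal windows stands in for the compact convolutional layers, and a single-head self-attention step with identity query/key/value projections stands in for the Transformer blocks (no learned weights, feed-forward layers, residual connections, or normalization). The function names `local_patch_tokens` and `self_attention` are hypothetical.

```python
import math


def local_patch_tokens(signal, patch_len):
    """Split a (channels x time) EEG trial into non-overlapping temporal patches.

    Each patch becomes one token: the per-channel mean of the samples in that
    window. Mean pooling is a stand-in for the compact convolutional layers.
    """
    n_time = len(signal[0])
    tokens = []
    for start in range(0, n_time - patch_len + 1, patch_len):
        tokens.append([sum(ch[start:start + patch_len]) / patch_len for ch in signal])
    return tokens


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens):
    """Single-head scaled dot-product self-attention with identity projections.

    Every output token is a softmax-weighted mix of all input tokens, which is
    what lets the encoder relate patches that are far apart in time.
    """
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, tokens)) for j in range(d)])
    return out


# Usage: 2 channels x 8 samples, patch length 4 -> 2 tokens of dimension 2
trial = [[1, 1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 2, 2]]
tokens = local_patch_tokens(trial, patch_len=4)  # [[1.0, 0.0], [0.0, 2.0]]
contextual = self_attention(tokens)
```

Each contextual token is a convex combination of all patch tokens, so information from a late patch can flow into an early one; the convolutional stage only ever sees its own window.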

2. Class-Wise Prototype Learning

For each class, ProtoAlignNet learns multiple prototypes from the distribution of patch-token embeddings. These prototypes act as class-specific representative patterns. Multiple prototypes are used because EEG signals from the same MI class may contain different local activation patterns across trials, subjects, and time segments.
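One common way to maintain several prototypes per class, used in non-parametric prototype learning, is to assign each patch token to a prototype of its class and move that prototype towards the mean of its assigned tokens with a momentum update. The sketch below assumes that scheme with a simple nearest-prototype assignment; ProtoAlignNet's actual update rule, momentum value, and assignment step (the README mentions a Sinkhorn-balanced assignment) may differ. `update_class_prototypes` is a hypothetical name.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against a zero vector
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def update_class_prototypes(prototypes, tokens, momentum=0.9):
    """Momentum update of the M prototypes of ONE class (hypothetical helper).

    prototypes: list of M unit-norm vectors for this class.
    tokens: patch-token embeddings taken from trials of this class.
    Each token is assigned to its most similar prototype (hard nearest-prototype
    assignment); each prototype then moves towards the mean of its tokens.
    """
    assigned = [[] for _ in prototypes]
    for t in tokens:
        t = l2_normalize(t)
        best = max(range(len(prototypes)), key=lambda m: dot(t, prototypes[m]))
        assigned[best].append(t)

    updated = []
    for proto, members in zip(prototypes, assigned):
        if not members:
            updated.append(proto)  # no tokens this step: keep the prototype as is
            continue
        mean = [sum(col) / len(members) for col in zip(*members)]
        moved = [momentum * p + (1 - momentum) * m for p, m in zip(proto, mean)]
        updated.append(l2_normalize(moved))  # project back onto the unit sphere
    return updated


# Usage: two prototypes for one class, each capturing a different local pattern
protos = [[1.0, 0.0], [0.0, 1.0]]
protos = update_class_prototypes(protos, [[2.0, 0.1], [0.1, 3.0]])
```

Because each prototype only absorbs the tokens closest to it, the two prototypes can drift towards different activation patterns of the same class instead of collapsing into a single class mean.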

3. Prototype-Guided Classification

Each patch token is compared with the learned prototype bank. The model keeps the strongest local evidence for each prototype and aggregates prototype-level responses to produce the final class prediction. This makes the decision process more interpretable than a standard dense classifier.
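A minimal sketch of this decision rule, under stated assumptions: similarity is cosine similarity, the strongest local evidence for a prototype is the maximum over patch tokens, and a class score is the mean of its prototypes' responses (taking the maximum over prototypes instead would be an equally plausible reading of "aggregates"). The prototype bank layout, class names, and function names are hypothetical.

```python
import math


def cosine(a, b):
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb + 1e-12)


def prototype_logits(tokens, prototype_bank):
    """Score each class from local patch evidence.

    tokens: patch-token embeddings of one trial.
    prototype_bank: dict mapping class label -> list of prototype vectors.
    """
    logits = {}
    for label, protos in prototype_bank.items():
        # strongest local evidence: best-matching patch for each prototype
        responses = [max(cosine(t, p) for t in tokens) for p in protos]
        # aggregate prototype-level responses (mean is an assumption)
        logits[label] = sum(responses) / len(responses)
    return logits


def predict(tokens, prototype_bank):
    logits = prototype_logits(tokens, prototype_bank)
    return max(logits, key=logits.get)


bank = {"left_hand": [[1.0, 0.0], [0.9, 0.1]], "right_hand": [[0.0, 1.0], [0.1, 0.9]]}
trial_tokens = [[1.0, 0.05], [0.9, 0.2]]
print(predict(trial_tokens, bank))  # left_hand
```

Each class score traces back to specific (patch, prototype) pairs, so recording which patch produced the maximum for each prototype shows where in the trial the evidence for a prediction came from.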

4. Local Patch-to-Prototype Consistency Loss

ProtoAlignNet introduces a local patch-to-prototype consistency loss to align each informative EEG patch with its correct class prototype while separating it from competing prototypes. The final training objective is:

$$L = L_{CE} + \lambda_{LPPC} \, L_{LPPC}$$

where L_CE is the cross-entropy classification loss, L_LPPC is the local patch-to-prototype consistency loss, and lambda_LPPC is a weighting coefficient that controls the contribution of L_LPPC.
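The README does not spell out the exact form of L_LPPC, so the following is a hypothetical instance of the idea described above, written as a prototype-level contrastive loss. Assumptions: the informative patches are the `top_k` tokens most similar to their own class's prototypes; each selected token is pulled towards its best-matching correct-class prototype and pushed away from all other prototypes through a softmax over temperature-scaled cosine similarities; `tau`, `top_k`, and the default `lambda_lppc` are illustrative values, not the paper's.

```python
import math


def cosine(a, b):
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb + 1e-12)


def log_softmax_at(scores, idx):
    m = max(scores)  # log-sum-exp with max subtraction for stability
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return scores[idx] - log_z


def lppc_loss(tokens, prototype_bank, label, tau=0.1, top_k=2):
    """Hypothetical patch-to-prototype consistency loss for one trial."""
    flat = [(c, p) for c, protos in prototype_bank.items() for p in protos]
    own = [i for i, (c, _) in enumerate(flat) if c == label]
    candidates = []
    for t in tokens:
        sims = [cosine(t, p) / tau for _, p in flat]
        target = max(own, key=lambda i: sims[i])  # best correct-class prototype
        candidates.append((sims[target], sims, target))
    # keep the top_k most informative patches (highest correct-class similarity)
    candidates.sort(key=lambda c: c[0], reverse=True)
    chosen = candidates[:top_k]
    # pull each chosen patch to its target prototype, push it from all others
    return -sum(log_softmax_at(sims, target) for _, sims, target in chosen) / len(chosen)


def total_loss(ce_loss, lppc, lambda_lppc=0.1):
    # L = L_CE + lambda_LPPC * L_LPPC
    return ce_loss + lambda_lppc * lppc


bank = {"left_hand": [[1.0, 0.0]], "right_hand": [[0.0, 1.0]]}
tokens = [[1.0, 0.0], [0.9, 0.1]]
aligned = lppc_loss(tokens, bank, "left_hand")      # small: patches match their class
misaligned = lppc_loss(tokens, bank, "right_hand")  # large: patches match the other class
```

Because the softmax runs over every prototype in the bank, including the other prototypes of the correct class, this form separates a patch from competing prototypes while letting different patches settle on different prototypes of the same class.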

Key Features

  • Compact CNN-Transformer encoder for MI-EEG patch representation learning.
  • Multiple class-wise prototypes to model intra-class EEG variability.
  • Sinkhorn-balanced assignment to encourage diverse prototype utilization.
  • Prototype-guided prediction based on local EEG patch evidence.
  • LPPC loss for token-level prototype alignment and improved feature separation.
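Sinkhorn-balanced assignment is usually implemented with the Sinkhorn-Knopp algorithm, as in SwAV and non-parametric prototype segmentation: token-to-prototype similarities are exponentiated and then alternately normalized over prototypes and over tokens, so every prototype receives a comparable share of tokens. The sketch below assumes cosine similarities in [-1, 1] and illustrative `eps` and `n_iters` values; how ProtoAlignNet configures this step is not stated in the README.

```python
import math


def sinkhorn(scores, eps=0.05, n_iters=3):
    """Balanced soft assignment of N tokens to K prototypes (Sinkhorn-Knopp).

    scores: N x K similarities, assumed to be cosine similarities in [-1, 1].
    Returns Q (N x K): each row sums to 1, and each column's total is pushed
    towards N / K, so every prototype receives a comparable share of tokens.
    """
    n, k = len(scores), len(scores[0])
    m = max(max(row) for row in scores)
    q = [[math.exp((s - m) / eps) for s in row] for row in scores]
    for _ in range(n_iters):
        # column step: each prototype gets total mass 1/K
        for j in range(k):
            col = sum(q[i][j] for i in range(n))
            for i in range(n):
                q[i][j] /= col * k
        # row step: each token gets total mass 1/N
        for i in range(n):
            row = sum(q[i])
            q[i] = [v / (row * n) for v in q[i]]
    return [[v * n for v in row] for row in q]  # rescale rows to sum to 1


# Usage: every token is closest to prototype 0, yet prototype 1 still gets mass
scores = [[0.9, 0.8], [0.85, 0.8], [0.9, 0.7], [0.8, 0.75]]
q = sinkhorn(scores, n_iters=200)
```

The rows of Q can serve as soft assignment targets or be hardened with an argmax. Because the columns are balanced, no single prototype can absorb all tokens, which is what keeps the multiple prototypes per class in use.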

ProtoAlignNet has been validated on the BCI Competition IV-2a, BCI Competition IV-2b, and High Gamma datasets, achieving classification accuracy superior to the compared methods in both subject-dependent and cross-subject settings.

Our research builds upon and improves the implementations of Proto Non-Parametric and EEG Conformer. We sincerely thank their authors for making these projects publicly available.

Overall Framework

[Figure: Overall framework of ProtoAlignNet]

ProtoAlignNet learns discriminative EEG patch representations by combining local-to-global feature extraction, class-wise prototype learning, prototype-guided classification, and local patch-to-prototype alignment.

Requirements

The code is implemented in Python with PyTorch. The main dependencies include:

  • Python 3.10
  • PyTorch 1.13.1
  • numpy
  • scipy
  • scikit-learn
  • mne
  • pyyaml
  • tqdm
  • matplotlib

Citation

If you use this code, please cite our paper:

Contact

For any inquiries or further information, feel free to contact us at sirajdurani@gmail.com.
