[Feature]: Generalize Prediction pipeline for Lightning CLI models#148
Conversation
There was a problem hiding this comment.
Pull request overview
This PR introduces a new generalized prediction pipeline intended to work with LightningCLI-saved models/checkpoints, including persisting classification label names into checkpoints for consistent prediction output formatting.
Changes:
- Add checkpoint persistence of
classification_labels(derived from a datasetclasses.txt) and wire the dataset path into model init via LightningCLI argument linking. - Introduce a new SMILES prediction entrypoint (
chebai/result/prediction.py) that reconstructs model/datamodule from checkpoint hyperparameters. - Refactor
XYBaseDataModule.predict_dataloaderto build a prediction dataloader from an in-memory SMILES list, plus update docs/tests and add VS Code workspace files.
Reviewed changes
Copilot reviewed 11 out of 12 changed files in this pull request and generated 11 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unit/cli/testCLI.py | Adjusts CLI unit test model args (smaller hidden layer). |
| tests/unit/cli/mock_dm.py | Adds classes_txt_file_path for CLI linking in tests. |
| tests/unit/cli/classification_labels.txt | Adds sample classification labels used by CLI unit tests. |
| chebai/trainer/CustomTrainer.py | Removes prior bespoke prediction logic and overrides predict(). |
| chebai/result/prediction.py | Adds new prediction script/class for SMILES/file inference from checkpoint. |
| chebai/preprocessing/datasets/base.py | Refactors prediction dataloader flow and adds classes_txt_file_path. |
| chebai/models/base.py | Adds label-file loading + saving classification_labels into checkpoints. |
| chebai/cli.py | Links data.classes_txt_file_path into model.init_args.classes_txt_file_path. |
| README.md | Updates prediction instructions to use the new prediction script. |
| .vscode/settings.json | Adds VS Code project settings (currently invalid JSON). |
| .vscode/extensions.json | Adds recommended VS Code extensions. |
| .gitignore | Stops ignoring the entire .vscode directory (only ignores launch.json). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 12 out of 13 changed files in this pull request and generated 8 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
|
@copilot open a new pull request to apply changes based on the comments in this thread |
|
@aditya0by0 I've opened a new pull request, #152, to work on those changes. Once the pull request is ready, I'll request review from you. |
Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
) * Initial plan * Address review comments from PR #148 Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
|
It is possible to add test for prediction pipeline for electra. we can limit the vocab size and labels for such mock model-pipeline. @sfluegel05, Do you think is there need for such test case OR the existing test from this PR is sufficient. import os
import tempfile
import torch
from chebai.models.electra import Electra
# Smallest viable config
model = Electra(
model_type="classification",
config={
"vocab_size": 10,
"max_position_embeddings": 1,
"num_attention_heads": 1,
"num_hidden_layers": 1,
"type_vocab_size": 1,
"hidden_size": 1,
"intermediate_size": 1,
},
out_dim=10,
input_dim=10,
)
# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params}")
# Save checkpoint and measure size
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_path = os.path.join(tmpdir, "electra_small.ckpt")
torch.save({"state_dict": model.state_dict()}, ckpt_path)
size_bytes = os.path.getsize(ckpt_path)
print(f"Checkpoint size: {size_bytes} bytes")(gnn) sh-4.4$ /home/staff/a/akhedekar/miniconda3/envs/gnn/bin/python /home/staff/a/akhedekar/python-chebai/test.py
Input dimension for the model: 10 Output dimension for the model: 10
Total parameters: 1959
Checkpoint size: 18367 bytes (0.018 MB) |

Generalize prediction logic
Please merge below PRs after this PR:
Related Discussion
Related bugs rectified in Lightning for the pipeline
LightningDataModule.load_from_checkpointdoes not restore subclass fromdatamodule_hyper_parametersLightning-AI/pytorch-lightning#21477save_hyperparameters(ignore=...)is not persistent across inheritance; ignored params reappear when base class also callssave_hyperparametersLightning-AI/pytorch-lightning#21488Additional changes
Save class labels in checkpoint under the key "classification_labels"
Wrap inference with
torch.inference_mode()to avoid gradient tracking (see Avoid gradient tracking python-chebifier#21)model.eval()in PyTorchtorch.no_grad()andtorch.inference_mode()Use
torch.compilefor faster inference