Test for CLI + project dependencies changes by aditya0by0 · Pull Request #105 · ChEB-AI/python-chebai · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/test.yml
3 changes: 1 addition & 2 deletions chebai/models/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class FFN(ChebaiBaseNet):

def __init__(
self,
input_size: int,
hidden_layers: List[int] = [
1024,
],
Expand All @@ -20,7 +19,7 @@ def __init__(
super().__init__(**kwargs)

layers = []
current_layer_input_size = input_size
current_layer_input_size = self.input_dim
for hidden_dim in hidden_layers:
layers.append(MLPBlock(current_layer_input_size, hidden_dim))
layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
Expand Down
1 change: 0 additions & 1 deletion configs/model/ffn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ class_path: chebai.models.ffn.FFN
init_args:
optimizer_kwargs:
lr: 1e-3
input_size: 2560
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ maintainers = [
]
readme = "README.md"
license = { text = "AGPL-3.0" }
requires-python = ">=3.9,<3.13"
requires-python = ">=3.10,<3.13"
dependencies = [
"networkx",
"numpy",
Expand All @@ -28,7 +28,7 @@ dependencies = [
"pysmiles==1.1.2",
"rdkit",
"selfies",
"lightning>=2.5",
"lightning<=2.5.1",
"jsonargparse[signatures]>=4.17",
"omegaconf",
"deepsmiles",
Expand Down
Empty file added tests/unit/cli/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions tests/unit/cli/mock_dm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
from lightning.pytorch.core.datamodule import LightningDataModule
from torch.utils.data import DataLoader

from chebai.preprocessing.collate import RaggedCollator


class MyLightningDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self._num_of_labels = None
self._feature_vector_size = None
self.collator = RaggedCollator()

def prepare_data(self):
pass

def setup(self, stage=None):
self._num_of_labels = 10
self._feature_vector_size = 20
print(f"Number of labels: {self._num_of_labels}")
print(f"Number of features: {self._feature_vector_size}")

@property
def num_of_labels(self):
return self._num_of_labels

@property
def feature_vector_size(self):
return self._feature_vector_size

def train_dataloader(self):
assert self.feature_vector_size is not None, "feature_vector_size must be set"
# Dummy dataset for example purposes

datalist = [
{
"features": torch.randn(self._feature_vector_size),
"labels": torch.randint(0, 2, (self._num_of_labels,), dtype=torch.bool),
"ident": i,
"group": None,
}
for i in range(100)
]

return DataLoader(datalist, batch_size=32, collate_fn=self.collator)
1 change: 1 addition & 0 deletions tests/unit/cli/mock_dm_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
class_path: tests.unit.cli.mock_dm.MyLightningDataModule
35 changes: 35 additions & 0 deletions tests/unit/cli/testCLI.py