Data Augmentation : SMILES by aditya0by0 · Pull Request #115 · ChEB-AI/python-chebai · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions README.md
39 changes: 34 additions & 5 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import lightning as pl
Expand Down Expand Up @@ -416,10 +417,17 @@ def prepare_data(self, *args, **kwargs) -> None:

self._prepare_data_flag += 1
self._perform_data_preparation(*args, **kwargs)
self._after_prepare_data(*args, **kwargs)

def _perform_data_preparation(self, *args, **kwargs) -> None:
raise NotImplementedError

def _after_prepare_data(self, *args, **kwargs) -> None:
"""
Hook to perform additional pre-processing after pre-processed data is available.
"""
...

def setup(self, *args, **kwargs) -> None:
"""
Setup the data module.
Expand Down Expand Up @@ -461,14 +469,17 @@ def _set_processed_data_props(self):
- self._num_of_labels: Number of target labels in the dataset.
- self._feature_vector_size: Maximum feature vector length across all data points.
"""
data_pt = torch.load(
os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
weights_only=False,
pt_file_path = os.path.join(
self.processed_dir, self.processed_file_names_dict["data"]
)
data_pt = torch.load(pt_file_path, weights_only=False)

self._num_of_labels = len(data_pt[0]["labels"])
self._feature_vector_size = max(len(d["features"]) for d in data_pt)

print(
f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples"
)
print(f"Number of labels for loaded data: {self._num_of_labels}")
print(f"Feature vector size: {self._feature_vector_size}")

Expand Down Expand Up @@ -731,6 +742,7 @@ def __init__(
self.splits_file_path = self._validate_splits_file_path(
kwargs.get("splits_file_path", None)
)
self._data_pkl_filename: str = "data.pkl"

@staticmethod
def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -869,6 +881,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
"""
pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))

def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
    """
    Load a processed dataset pickle from the main processed directory.

    Args:
        filename (str): Name of the pickle file, resolved relative to
            `self.processed_dir_main`.

    Returns:
        Optional[pd.DataFrame]: The unpickled dataset, or None when no file
        exists at the resolved path.
    """
    file_path = Path(self.processed_dir_main) / filename
    # Read directly instead of checking existence first: a file that vanishes
    # between a check and the read then yields None rather than an exception.
    try:
        return pd.read_pickle(file_path)
    except FileNotFoundError:
        return None

# ------------------------------ Phase: Setup data -----------------------------------
def setup_processed(self) -> None:
"""
Expand Down Expand Up @@ -907,7 +934,9 @@ def _get_data_size(input_file_path: str) -> int:
int: The size of the data.
"""
with open(input_file_path, "rb") as f:
return len(pd.read_pickle(f))
df = pd.read_pickle(f)
print(f"Processed data size ({input_file_path}): {len(df)} rows")
return len(df)

@abstractmethod
def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
Expand Down Expand Up @@ -1223,7 +1252,7 @@ def processed_main_file_names_dict(self) -> dict:
dict: A dictionary mapping dataset key to their respective file names.
For example, {"data": "data.pkl"}.
"""
return {"data": "data.pkl"}
return {"data": self._data_pkl_filename}

@property
def raw_file_names(self) -> List[str]:
Expand Down
130 changes: 127 additions & 3 deletions chebai/preprocessing/datasets/chebi.py