add pretraining by sfluegel05 · Pull Request #1 · ChEB-AI/python-chebai-graph · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added chebai_graph/loss/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions chebai_graph/loss/pretraining.py
144 changes: 120 additions & 24 deletions chebai_graph/models/graph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import typing

from torch import nn
from torch_geometric import nn as tgnn
from torch_scatter import scatter_add, scatter_mean
import torch
import torch.nn.functional as F
from chebai.models.base import ChebaiBaseNet
from chebai.preprocessing.structures import XYData
from torch import nn
from torch_geometric import nn as tgnn
from torch_geometric.data import Data as GraphData
from torch_scatter import scatter_add, scatter_mean

from chebai.models.base import ChebaiBaseNet
from chebai_graph.loss.pretraining import MaskPretrainingLoss

# Raise the threshold of the "pysmiles" logger to CRITICAL at import time, so
# only CRITICAL records from that logger are emitted.
# NOTE(review): presumably meant to silence pysmiles parser warnings — confirm.
logging.getLogger("pysmiles").setLevel(logging.CRITICAL)

Expand All @@ -17,6 +19,9 @@ class GraphBaseNet(ChebaiBaseNet):
def _get_prediction_and_labels(self, data, labels, output):
return torch.sigmoid(output), labels.int()

def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor:
return batch.y.float() if batch.y is not None else None


class JCIGraphNet(GraphBaseNet):
NAME = "GNN"
Expand Down Expand Up @@ -72,10 +77,10 @@ def forward(self, batch):
return a


class ResGatedGraphConvNet(GraphBaseNet):
class ResGatedGraphConvNetBase(GraphBaseNet):
"""GNN that supports edge attributes"""

NAME = "ResGatedGraphConvNet"
NAME = "ResGatedGraphConvNetBase"

def __init__(self, config: typing.Dict, **kwargs):
super().__init__(**kwargs)
Expand All @@ -88,7 +93,9 @@ def __init__(self, config: typing.Dict, **kwargs):
config["n_linear_layers"] if "n_linear_layers" in config else 3
)
self.n_atom_properties = int(config["n_atom_properties"])
self.n_bond_properties = int(config["n_bond_properties"])
self.n_bond_properties = (
int(config["n_bond_properties"]) if "n_bond_properties" in config else 7
)
self.n_molecule_properties = (
int(config["n_molecule_properties"])
if "n_molecule_properties" in config
Expand Down Expand Up @@ -118,21 +125,6 @@ def __init__(self, config: typing.Dict, **kwargs):
self.in_length, self.hidden_length, edge_dim=self.n_bond_properties
)

self.linear_layers = torch.nn.ModuleList([])
for i in range(self.n_linear_layers - 1):
if i == 0:
self.linear_layers.append(
nn.Linear(
self.hidden_length + self.n_molecule_properties,
self.hidden_length,
)
)
else:
self.linear_layers.append(
nn.Linear(self.hidden_length, self.hidden_length)
)
self.final_layer = nn.Linear(self.hidden_length, self.out_dim)

def forward(self, batch):
graph_data = batch["features"][0]
assert isinstance(graph_data, GraphData)
Expand All @@ -149,17 +141,121 @@ def forward(self, batch):
a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr
)
)
a = self.dropout(a)
return a


class ResGatedGraphConvNetGraphPred(GraphBaseNet):
    """GNN for graph-level prediction.

    A ``ResGatedGraphConvNetBase`` encoder (newly built, or taken from a
    pretraining checkpoint) produces node embeddings. These are summed per
    graph, concatenated with the molecule-level attributes and passed through
    a stack of linear layers.
    """

    NAME = "ResGatedGraphConvNetPred"

    def __init__(
        self,
        config: typing.Dict,
        n_linear_layers=2,
        pretrained_checkpoint=None,
        **kwargs,
    ):
        """Build the encoder and the prediction head.

        Args:
            config: Configuration passed to ``ResGatedGraphConvNetBase`` when
                no checkpoint is given.
            n_linear_layers: Total number of linear layers in the head,
                counting the final output layer.
            pretrained_checkpoint: Optional checkpoint of a
                ``ResGatedGraphConvNetPretrain`` model; if truthy, its
                encoder is reused instead of building a new one.
            **kwargs: Forwarded to the base-class constructor (and to the
                encoder constructor when it is built here).
        """
        super().__init__(**kwargs)
        if pretrained_checkpoint:
            self.gnn = ResGatedGraphConvNetPretrain.load_from_checkpoint(
                pretrained_checkpoint, map_location=self.device
            ).as_pretrained
        else:
            self.gnn = ResGatedGraphConvNetBase(config, **kwargs)

        hidden = self.gnn.hidden_length
        # forward() concatenates the molecule attributes onto the pooled
        # embedding, so the first layer of the head must accept that extra
        # width. Assumes molecule_attr has n_molecule_properties columns.
        pooled_width = hidden + self.gnn.n_molecule_properties
        n_hidden = max(n_linear_layers - 1, 0)
        in_widths = ([pooled_width] + [hidden] * n_hidden)[:n_hidden]
        self.linear_layers = torch.nn.ModuleList(
            [torch.nn.Linear(width, hidden) for width in in_widths]
        )
        # Without hidden layers the output layer sees the pooled vector directly.
        final_in = hidden if in_widths else pooled_width
        self.final_layer = nn.Linear(final_in, self.out_dim)

    def forward(self, batch):
        """Compute graph-level outputs for ``batch``.

        Reads the graph from ``batch["features"][0]``, sums the encoder's
        node embeddings per graph via ``graph_data.batch``, appends
        ``graph_data.molecule_attr`` and applies the linear head. The result
        of ``final_layer`` is returned without a final activation.
        """
        graph_data = batch["features"][0]
        assert isinstance(graph_data, GraphData)
        a = self.gnn(batch)
        # Sum-pool node embeddings into one vector per graph.
        a = scatter_add(a, graph_data.batch, dim=0)

        a = torch.cat([a, graph_data.molecule_attr], dim=1)

        for lin in self.linear_layers:
            # The activation belongs to the encoder; this class defines none.
            a = self.gnn.activation(lin(a))
        a = self.final_layer(a)
        return a


class ResGatedGraphConvNetPretrain(GraphBaseNet):
    """Pretraining model: the base GNN plus a linear head predicting atom properties."""

    NAME = "ResGatedGraphConvNetPre"

    def __init__(self, config: typing.Dict, **kwargs):
        """Build the encoder and the atom-property head.

        If ``criterion`` is missing from ``kwargs`` or is ``None``, a
        ``MaskPretrainingLoss`` instance is used.
        """
        if kwargs.get("criterion") is None:
            kwargs["criterion"] = MaskPretrainingLoss()
        print(f"Initing ResGatedGraphConvNetPre with criterion: {kwargs['criterion']}")
        super().__init__(**kwargs)
        self.gnn = ResGatedGraphConvNetBase(config, **kwargs)
        encoder_width = self.gnn.hidden_length
        self.atom_prediction = nn.Linear(encoder_width, self.gnn.n_atom_properties)

    def forward(self, batch):
        """Predict the masked property of each masked atom.

        Embeds all nodes, keeps the rows listed in ``masked_atom_indices``,
        scores every atom property for them and gathers, along dim 1, the
        entries selected by ``masked_property_indices``.
        """
        graph = batch["features"][0]
        node_embeddings = self.gnn(batch)
        masked_nodes = node_embeddings[graph.masked_atom_indices.int()]
        all_property_scores = self.atom_prediction(masked_nodes)
        selected = graph.masked_property_indices.to(torch.int64)
        return torch.gather(all_property_scores, 1, selected)

    @property
    def as_pretrained(self):
        """The wrapped encoder, for reuse by downstream models."""
        return self.gnn

    def _process_labels_in_batch(self, batch):
        """Return the node labels stored as ``mask_node_label`` on ``batch.x[0]``."""
        first_graph = batch.x[0]
        return first_graph.mask_node_label


class ResGatedGraphConvNetPretrainBonds(ResGatedGraphConvNetPretrain):
    """For pretraining. BaseNet with two output layers for predicting atom and bond properties.

    Derives from ``ResGatedGraphConvNetPretrain`` so that the encoder
    (``self.gnn``) and the atom head (``self.atom_prediction``) used below
    actually exist; only the bond head is added here.
    """

    NAME = "ResGatedGraphConvNetPreBonds"

    def __init__(self, config: typing.Dict, **kwargs):
        """Build the pretraining model and add the bond-property head.

        The parent constructor installs the default ``MaskPretrainingLoss``
        criterion when none is given, and creates ``self.gnn`` and
        ``self.atom_prediction``.
        """
        super().__init__(config, **kwargs)
        self.bond_prediction = nn.Linear(
            self.gnn.hidden_length, self.gnn.n_bond_properties
        )

    def forward(self, batch):
        """Return ``(atom_pred, bond_pred)`` for the masked atoms and edges.

        Atom predictions are gathered as in the parent class. For every edge
        listed in ``connected_edge_indices`` the embeddings of its two
        endpoints are added and passed through the bond head.
        """
        data = batch["features"][0]
        embedding = self.gnn(batch)
        node_rep = embedding[data.masked_atom_indices.int()]
        atom_pred_all_properties = self.atom_prediction(node_rep)
        atom_pred = torch.gather(
            atom_pred_all_properties, 1, data.masked_property_indices.to(torch.int64)
        )

        # Columns of edge_index belonging to the selected edges: row 0 and
        # row 1 hold the two endpoint node indices of each edge.
        masked_edge_index = data.edge_index[:, data.connected_edge_indices.int()].int()
        edge_rep = embedding[masked_edge_index[0]] + embedding[masked_edge_index[1]]
        bond_pred = self.bond_prediction(edge_rep)
        return atom_pred, bond_pred

    def _get_prediction_and_labels(self, data, labels, output):
        """Apply sigmoid to each output and cast tuple labels to int.

        ``labels`` is converted element-wise only when it is a tuple;
        otherwise it is returned unchanged.
        """
        if isinstance(labels, tuple):
            labels = tuple(label.int() for label in labels)
        return tuple(torch.sigmoid(out) for out in output), labels

    def _process_labels_in_batch(self, batch):
        """Return the ``(mask_node_label, mask_edge_label)`` pair from ``batch.x[0]``."""
        return batch.x[0].mask_node_label, batch.x[0].mask_edge_label


class JCIGraphAttentionNet(GraphBaseNet):
NAME = "AGNN"

Expand Down
Empty file.
13 changes: 13 additions & 0 deletions chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-1
1
0
2
3
-4
-2
4
-3
5
6
7
-5
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
UNSPECIFIED
SP3
S
SP3D2
SP3D
SP
SP2
7 changes: 7 additions & 0 deletions chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
0
3
2
4
1
5
6
Empty file.
119 changes: 119 additions & 0 deletions chebai_graph/preprocessing/bin/AtomType/indices_one_hot.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
0
13
9
11
7
47
6
4
19
50
51
17
12
8
20
27
38
81
3
30
29
39
55
25
56
70
88
40
37
26
82
78
80
16
33
1
35
15
5
44
53
85
79
24
14
87
32
2
48
63
76
92
23
64
57
34
22
28
74
42
46
52
83
62
49
58
71
65
67
77
31
59
75
45
36
54
18
10
86
43
84
118
90
73
113
114
115
116
117
112
72
104
41
105
106
107
108
109
110
111
21
103
60
61
66
68
69
89
91
93
94
95
96
97
98
99
100
101
102
Empty file.
5 changes: 5 additions & 0 deletions chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt
Empty file.
Empty file.
Loading