Merge pull request #1 from michiyasunaga/main · snap-stanford/GreaseLM@a194c0d · GitHub
Skip to content

Commit a194c0d

Browse files
authored
Merge pull request #1 from michiyasunaga/main
add medqa
2 parents 079acd4 + 596aa50 commit a194c0d

5 files changed

Lines changed: 718 additions & 28 deletions

File tree

README.md

Lines changed: 12 additions & 2 deletions

greaselm.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DECODER_DEFAULT_LR = {
2727
'csqa': 1e-3,
2828
'obqa': 3e-4,
29+
'medqa_usmle': 1e-3,
2930
}
3031

3132
import numpy as np
@@ -81,8 +82,11 @@ def construct_model(args, kg):
8182
##########################################################
8283

8384
if kg == "cpnet":
84-
n_ntype = 4
85+
n_ntype = 4
8586
n_etype = 38
87+
elif kg == "ddb":
88+
n_ntype = 4
89+
n_etype = 34
8690
else:
8791
raise ValueError("Invalid KG.")
8892
if args.cxt_node_connects_all:
@@ -178,7 +182,7 @@ def calc_eval_accuracy(eval_set, model, loss_type, loss_func, debug, save_test_p
178182

179183
def train(args, resume, has_test_split, devices, kg):
180184
print("args: {}".format(args))
181-
185+
182186
if resume:
183187
args.save_dir = os.path.dirname(args.resume_checkpoint)
184188
if not args.debug:
@@ -211,7 +215,14 @@ def train(args, resume, has_test_split, devices, kg):
211215

212216
# Get the names of the loaded LM parameters
213217
loading_info = model.lmgnn.loading_info
214-
loaded_roberta_keys = [k.replace("roberta.", "lmgnn.mp.") for k in loading_info["all_keys"]]
218+
# loaded_roberta_keys = [k.replace("roberta.", "lmgnn.mp.") for k in loading_info["all_keys"]]
219+
def _rename_key(key):
220+
if key.startswith("roberta."):
221+
return key.replace("roberta.", "lmgnn.mp.")
222+
else:
223+
return "lmgnn.mp." + key
224+
225+
loaded_roberta_keys = [_rename_key(k) for k in loading_info["all_keys"]]
215226

216227
# Separate the parameters into loaded and not loaded
217228
loaded_params, not_loaded_params, params_to_freeze, small_lr_params, large_lr_params = sep_params(model, loaded_roberta_keys)
@@ -316,7 +327,7 @@ def train(args, resume, has_test_split, devices, kg):
316327
model.train()
317328

318329
for qids, labels, *input_data in tqdm(train_dataloader, desc="Batch"):
319-
# labels: [bs]
330+
# labels: [bs]
320331
start_time = time.time()
321332
optimizer.zero_grad()
322333
bs = labels.size(0)
@@ -387,11 +398,11 @@ def train(args, resume, has_test_split, devices, kg):
387398
if not args.debug:
388399
with open(log_path, 'a') as fout:
389400
fout.write('{:3},{:5},{:7.4f},{:7.4f},{:7.4f},{:7.4f},{:3}\n'.format(epoch_id, global_step, dev_acc, test_acc, best_dev_acc, final_test_acc, best_dev_epoch))
390-
401+
391402
wandb.log({"dev_acc": dev_acc, "dev_loss": dev_total_loss, "best_dev_acc": best_dev_acc, "best_dev_epoch": best_dev_epoch}, step=global_step)
392403
if has_test_split:
393404
wandb.log({"test_acc": test_acc, "test_loss": test_total_loss, "final_test_acc": final_test_acc}, step=global_step)
394-
405+
395406
# Save the model checkpoint
396407
if args.save_model:
397408
model_state_dict = model.state_dict()
@@ -500,10 +511,12 @@ def main(args):
500511
logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(name)s:%(funcName)s():%(lineno)d] %(message)s',
501512
datefmt='%m/%d/%Y %H:%M:%S',
502513
level=logging.WARNING)
503-
514+
504515
has_test_split = True
505516
devices = get_devices(args.cuda)
506517
kg = "cpnet"
518+
if args.dataset == "medqa_usmle":
519+
kg = "ddb"
507520

508521
if not args.use_wandb:
509522
wandb_mode = "disabled"
@@ -518,7 +531,7 @@ def main(args):
518531
args.wandb_id = wandb_id
519532

520533
args.hf_version = transformers.__version__
521-
534+
522535
with wandb.init(project="KG-LM", config=args, name=args.run_name, resume="allow", id=wandb_id, settings=wandb.Settings(start_method="fork"), mode=wandb_mode):
523536
print(socket.gethostname())
524537
print ("pid:", os.getpid())
@@ -537,7 +550,7 @@ def main(args):
537550

538551
if __name__ == '__main__':
539552
__spec__ = None
540-
553+
541554
parser = parser_utils.get_parser()
542555
args, _ = parser.parse_known_args()
543556

@@ -590,4 +603,4 @@ def main(args):
590603
parser.add_argument('--init_range', default=0.02, type=float, help='stddev when initializing with normal distribution')
591604

592605
args = parser.parse_args()
593-
main(args)
606+
main(args)

modeling/modeling_greaselm.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26+
if os.environ.get('INHERIT_BERT', 0):
27+
ModelClass = modeling_bert.BertModel
28+
else:
29+
ModelClass = modeling_roberta.RobertaModel
30+
31+
print ('ModelClass', ModelClass)
32+
33+
2634
class GreaseLM(nn.Module):
2735

2836
def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
@@ -31,11 +39,11 @@ def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=
3139
pretrained_concept_emb=None, freeze_ent_emb=True,
3240
init_range=0.02, ie_dim=200, info_exchange=True, ie_layer_num=1, sep_ie_layers=False, layer_id=-1):
3341
super().__init__()
34-
self.lmgnn = LMGNN(args, model_name, k, n_ntype, n_etype,
42+
self.lmgnn = LMGNN(args, model_name, k, n_ntype, n_etype,
3543
n_concept, concept_dim, concept_in_dim, n_attention_head,
3644
fc_dim, n_fc_layer, p_emb, p_gnn, p_fc, pretrained_concept_emb=pretrained_concept_emb, freeze_ent_emb=freeze_ent_emb,
3745
init_range=init_range, ie_dim=ie_dim, info_exchange=info_exchange, ie_layer_num=ie_layer_num, sep_ie_layers=sep_ie_layers, layer_id=layer_id)
38-
46+
3947
def batch_graph(self, edge_index_init, edge_type_init, n_nodes):
4048
"""
4149
edge_index_init: list of (n_examples, ). each entry is torch.tensor(2, E)
@@ -59,7 +67,7 @@ def forward(self, *inputs, cache_output=False, detail=False):
5967
-> (2, total E)
6068
edge_type: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(E(variable), )
6169
-> (total E, )
62-
70+
6371
returns:
6472
logits: [bs, nc]
6573
"""
@@ -85,7 +93,7 @@ def forward(self, *inputs, cache_output=False, detail=False):
8593
return logits, attn, concept_ids.view(bs, nc, -1), node_type_ids.view(bs, nc, -1), edge_index_orig, edge_type_orig
8694
# edge_index_orig: list of (batch_size, num_choice). each entry is torch.tensor(2, E)
8795
# edge_type_orig: list of (batch_size, num_choice). each entry is torch.tensor(E, )
88-
96+
8997
def get_fake_inputs(self, device="cuda:0"):
9098
bs = 4
9199
nc = 5
@@ -129,14 +137,14 @@ def test_GreaseLM(device):
129137

130138
class LMGNN(nn.Module):
131139

132-
def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
140+
def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
133141
n_concept=799273, concept_dim=200, concept_in_dim=1024, n_attention_head=2,
134142
fc_dim=200, n_fc_layer=0, p_emb=0.2, p_gnn=0.2, p_fc=0.2,
135143
pretrained_concept_emb=None, freeze_ent_emb=True,
136144
init_range=0.02, ie_dim=200, info_exchange=True, ie_layer_num=1, sep_ie_layers=False, layer_id=-1):
137145
super().__init__()
138-
config, _ = modeling_roberta.RobertaModel.config_class.from_pretrained(
139-
model_name,
146+
config, _ = ModelClass.config_class.from_pretrained(
147+
model_name,
140148
cache_dir=None, return_unused_kwargs=True,
141149
force_download=False,
142150
output_hidden_states=True
@@ -281,11 +289,12 @@ def test_LMGNN(device):
281289
model.check_outputs(*outputs)
282290

283291

284-
class TextKGMessagePassing(modeling_roberta.RobertaModel):
292+
293+
class TextKGMessagePassing(ModelClass):
285294

286295
def __init__(self, config, args={}, k=5, n_ntype=4, n_etype=38, dropout=0.2, concept_dim=200, ie_dim=200, p_fc=0.2, info_exchange=True, ie_layer_num=1, sep_ie_layers=False):
287296
super().__init__(config=config)
288-
297+
289298
self.n_ntype = n_ntype
290299
self.n_etype = n_etype
291300

@@ -633,7 +642,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
633642
state_dict = state_dict.copy()
634643
if metadata is not None:
635644
state_dict._metadata = metadata
636-
645+
637646
all_keys = list(state_dict.keys())
638647

639648
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
@@ -756,7 +765,7 @@ class RoBERTaGAT(modeling_bert.BertEncoder):
756765

757766
def __init__(self, config, k=5, n_ntype=4, n_etype=38, hidden_size=200, dropout=0.2, concept_dim=200, ie_dim=200, p_fc=0.2, info_exchange=True, ie_layer_num=1, sep_ie_layers=False):
758767
super().__init__(config)
759-
768+
760769
self.k = k
761770
self.edge_encoder = torch.nn.Sequential(torch.nn.Linear(n_etype + 1 + n_ntype * 2, hidden_size), torch.nn.BatchNorm1d(hidden_size), torch.nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size))
762771
self.gnn_layers = nn.ModuleList([modeling_gnn.GATConvE(hidden_size, n_ntype, n_etype, self.edge_encoder) for _ in range(k)])
@@ -799,14 +808,14 @@ def forward(self, hidden_states, attention_mask, special_tokens_mask, head_mask,
799808

800809
if output_attentions:
801810
all_attentions = all_attentions + (layer_outputs[1],)
802-
811+
803812
if i >= self.num_hidden_layers - self.k:
804813
# GNN
805814
gnn_layer_index = i - self.num_hidden_layers + self.k
806815
_X = self.gnn_layers[gnn_layer_index](_X, edge_index, edge_type, _node_type, _node_feature_extra)
807816
_X = self.activation(_X)
808817
_X = F.dropout(_X, self.dropout_rate, training = self.training)
809-
818+
810819
# Exchange info between LM and GNN hidden states (Modality interaction)
811820
if self.info_exchange == True or (self.info_exchange == "every-other-layer" and (i - self.num_hidden_layers + self.k) % 2 == 0):
812821
X = _X.view(bs, -1, _X.size(1)) # [bs, max_num_nodes, node_dim]
@@ -861,7 +870,7 @@ def check_outputs(self, outputs, _X):
861870

862871
def test_RoBERTaGAT(device):
863872
config, _ = modeling_roberta.RobertaModel.config_class.from_pretrained(
864-
"roberta-large",
873+
"roberta-large",
865874
cache_dir=None, return_unused_kwargs=True,
866875
force_download=False,
867876
output_hidden_states=True
@@ -880,11 +889,11 @@ def test_RoBERTaGAT(device):
880889
utils.print_cuda_info()
881890
free_gpus = utils.select_free_gpus()
882891
device = torch.device("cuda:{}".format(free_gpus[0]))
883-
892+
884893
# test_RoBERTaGAT(device)
885894

886895
# test_TextKGMessagePassing(device)
887896

888897
# test_LMGNN(device)
889898

890-
test_GreaseLM(device)
899+
test_GreaseLM(device)

run_greaselm__medqa_usmle.sh

Lines changed: 69 additions & 0 deletions

0 commit comments

Comments
 (0)