We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 536d405 + b7da162 commit 5dfd6dcCopy full SHA for 5dfd6dc
1 file changed
util/loss_torch.py
@@ -32,15 +32,22 @@ def batch_softmax_loss(user_emb, item_emb, temperature):
32
return torch.mean(loss)
33
34
35
-def InfoNCE(view1, view2, temperature, b_cos = True):
+def InfoNCE(view1, view2, temperature: float, b_cos: bool = True):
36
+ """
37
+ Args:
38
+ view1: (torch.Tensor - N x D)
39
+ view2: (torch.Tensor - N x D)
40
+ temperature: float
41
+ b_cos (bool)
42
+
43
+ Return: Average InfoNCE Loss
44
45
if b_cos:
46
view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
- pos_score = (view1 * view2).sum(dim=-1)
- pos_score = torch.exp(pos_score / temperature)
- ttl_score = torch.matmul(view1, view2.transpose(0, 1))
- ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
- cl_loss = -torch.log(pos_score / ttl_score+10e-6)
- return torch.mean(cl_loss)
47
48
+ pos_score = (view1 @ view2.T) / temperature
49
+ score = torch.diag(F.log_softmax(pos_score, dim=1))
50
+ return -score.mean()
51
52
53
#this version is from recbole
0 commit comments