The Unbearable Weight
Text generation is one of the most typical NLP tasks. Let the model parameters be $\boldsymbol{\theta}$, the given context be $\boldsymbol{c}$, and the text to be generated be $\boldsymbol{w}$. We usually rely on maximum likelihood estimation to make the model's predicted distribution $p_{\boldsymbol{\theta}}$ as close as possible to the training-set distribution $p_{d}(\boldsymbol{w})$:
$$ \boldsymbol{\theta ^{\ast}} = \mathop{\arg \max}_{\boldsymbol{\theta }} \ \mathbb{E}_{\boldsymbol{c}, \boldsymbol{w} \sim p_{d}} \log p_{\boldsymbol{\theta }}(\boldsymbol{w}|\boldsymbol{c};\boldsymbol{\theta }) $$
When modeling, we usually put a Softmax on top of the model to turn scores into probabilities, so that the predicted probabilities over the whole vocabulary $\mathcal{V}$ sum to 1:
$$ p(\boldsymbol{w}|\boldsymbol{c};\boldsymbol{\theta}) = \frac{u_{\boldsymbol{\theta }}(\boldsymbol{w}, \boldsymbol{c})}{\underbrace{ {\color{blue}\sum_{\boldsymbol{w}' \in \mathcal{V}} u_{\boldsymbol{\theta}}(\boldsymbol{w}', \boldsymbol{c})} }_{ \text{Partition Function} }}, u_{\boldsymbol{\theta }}(\boldsymbol{w}, \boldsymbol{c}) = \exp(s_{\theta }(\boldsymbol{w}, \boldsymbol{c})) $$
The denominator is there for normalization and is also called the partition function. To keep the notation compact, we collapse the formula above and denote the denominator by $Z(\boldsymbol{\theta})$.
For an ordinary multi-class problem with a modest number of classes, computing $Z(\boldsymbol{\theta})$ is painless. But for text generation, e.g. filling in the blank in "Text ____ is a natural language processing task", you have to consider every word in the vocabulary $\mathcal{V}$, and computing $Z(\boldsymbol{\theta})$ becomes very expensive.
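To make the cost concrete, here is a minimal sketch of the full-softmax normalization (the vocabulary size and dimensions are made-up values for illustration only): every step has to sum over the entire vocabulary for every context.

```python
import torch

# hypothetical sizes, for illustration only
vocab_size, hidden_dim, batch = 50_000, 512, 32

W = torch.randn(vocab_size, hidden_dim)   # output embeddings, one row per word in V
h = torch.randn(batch, hidden_dim)        # context representations

scores = h @ W.T                          # (batch, |V|): s_theta(w, c) for every w
log_Z = torch.logsumexp(scores, dim=-1)   # the partition function, O(|V|) per context
log_p = scores - log_Z[:, None]           # normalized log p(w | c)
```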
Leave It to the Parameters
So someone says: fine, let's just hand it over to the parameters, let the model learn it on its own, and see whether it can figure out the normalization by itself:
$$ p(\boldsymbol{w}|\boldsymbol{c};\boldsymbol{\theta}) = \frac{u_{\boldsymbol{\theta }}(\boldsymbol{w}, \boldsymbol{c})}{Z(\boldsymbol{\theta })} = u_{\boldsymbol{\theta }}(\boldsymbol{w}, \boldsymbol{c})\exp(z^{\boldsymbol{c}}), \, z^{\boldsymbol{c}} = -\log Z(\boldsymbol{\theta }) $$
Then apply maximum likelihood:
$$ \begin{align} \boldsymbol{\theta ^{\ast}} & = \mathop{\arg \max}_{\boldsymbol{\theta }} \ \mathbb{E}_{\boldsymbol{c}, \boldsymbol{w} \sim p_{d}} \log p_{\boldsymbol{\theta }}(\boldsymbol{w}|\boldsymbol{c};\boldsymbol{\theta }) \\ &= \mathop{\arg \max}_{\boldsymbol{\theta }, \boldsymbol{z}} \ \mathbb{E}_{\boldsymbol{c}, \boldsymbol{w} \sim p_{d}} \log u_{\boldsymbol{\theta }}(\boldsymbol{w},\boldsymbol{c})\exp(z^{\boldsymbol{c}}) \end{align} $$
The result is that, in order to maximize the expectation, the optimizer drives $Z(\boldsymbol{\theta}) \to 0$, and the model ends up performing poorly.
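To see why, note that the objective is unbounded in $z^{\boldsymbol{c}}$: for any fixed scores,

$$ \frac{\partial}{\partial z^{\boldsymbol{c}}} \Big( \log u_{\boldsymbol{\theta}}(\boldsymbol{w}, \boldsymbol{c}) + z^{\boldsymbol{c}} \Big) = 1 > 0 $$

so the optimizer can always increase the objective by pushing $z^{\boldsymbol{c}}$ up (i.e. $Z(\boldsymbol{\theta}) \to 0$), regardless of how good the scores themselves are.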
A Winding Path
Noise Contrastive Estimation (NCE) then asks: can we keep the extra parameter and still estimate $Z(\boldsymbol{\theta})$ well? Its answer is to turn the original multi-class problem into a binary classification problem (a proxy problem: a new task or metric that stands in for the original modeling objective). Concretely:
First, there are two distributions: a noise distribution $p_{n}$ and an empirical data distribution $p_{d}$. Here $p_{d}$ is extracted from the training set, much like in word2vec training: split sentences into words (modern NLP mostly works with tokens; words are used here for simplicity), count how often words appear together, and then $p(\boldsymbol{w}|\boldsymbol{c})$ is the probability that the next word is $\boldsymbol{w}$ given the context $\boldsymbol{c}$. For example, for the context love: $p_{d}=\{ \text{games}: 0.9, \text{study}: 0.1 \}$.
Each time we draw one candidate word from $p_{d}$ and $k$ candidate words from $p_{n}$. The model's task is to tell whether a candidate was sampled from the training set or from the noise; through this proxy task, $p_{\boldsymbol{\theta}}(\boldsymbol{w})$ is pushed toward $p_{d}(\boldsymbol{w})$ (the conditioning on $\boldsymbol{c}$ is kept implicit from here on).
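As a minimal sketch of this sampling scheme (the vocabulary and distributions below are toy values, not from a real corpus):

```python
import torch

vocab = ["games", "study", "music", "sports"]
p_d = torch.tensor([0.9, 0.1, 0.0, 0.0])            # empirical next-word distribution for the context "love"
p_n = torch.full((4,), 0.25)                         # noise distribution, uniform here
k = 3

pos = torch.multinomial(p_d, 1)                      # one candidate from the data
neg = torch.multinomial(p_n, k, replacement=True)    # k candidates from the noise
print([vocab[i] for i in pos], [vocab[i] for i in neg])
```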
We stipulate that $\mathcal{D}=1$ denotes a sample drawn from the training set and $\mathcal{D}=0$ one drawn from the noise, so that:
$$ \begin{align} p(\boldsymbol{w}|\mathcal{D} =1, \boldsymbol{c}) & = p_{d}(\boldsymbol{w})\\ p(\boldsymbol{w}|\mathcal{D} = 0, {\boldsymbol{c}}) & = p_{n}({\boldsymbol{w}}) \end{align} $$
The overall (marginal) probability of a candidate $\boldsymbol{w}$ is then:
$$ p_{joint}(\boldsymbol{w}) =\frac{1}{k+1}p_{d}(\boldsymbol{w}) + \frac{k}{k+1} p_{n}(\boldsymbol{w}) $$
Next, compute the posterior probability of which source a sample came from:
$$ \begin{align} p(\mathcal{D}=1|\boldsymbol{w}) & = \frac{p(\mathcal{D}=1,\boldsymbol{w})}{p_{joint}(\boldsymbol{w})} = \frac{p(\boldsymbol{w}|\mathcal{D}=1)p(\mathcal{D}=1)}{p_{joint}(\boldsymbol{w})} \\ & = \frac{p_{d}(\boldsymbol{w})}{p_{d}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \end{align} $$
Similarly:
$$ p(\mathcal{D}=0|\boldsymbol{w}) = \frac{kp_{n}(\boldsymbol{w})}{p_{d}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} $$
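As a quick sanity check with made-up numbers, say $p_{d}(\boldsymbol{w}) = 0.9$, $p_{n}(\boldsymbol{w}) = 0.25$ and $k = 3$:

$$ p(\mathcal{D}=1|\boldsymbol{w}) = \frac{0.9}{0.9 + 3 \times 0.25} \approx 0.55, \qquad p(\mathcal{D}=0|\boldsymbol{w}) = \frac{0.75}{1.65} \approx 0.45 $$

and the two posteriors sum to 1 as expected.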
Here $\boldsymbol{w} \sim \mathcal{D}$ means sampling from the combined pool of training-set and noise candidates, and the NCE objective is to maximize the expectation:
$$ \boldsymbol{\theta }^{\ast} = \mathop{\arg \max}_{\boldsymbol{\theta }}\ \mathbb{E}_{\boldsymbol{w} \sim \mathcal{D}} \log p(\mathcal{D}|\boldsymbol{w}) $$
Since we want the model distribution $p_{\boldsymbol{\theta}}$ to be as close as possible to the training-set distribution $p_{d}$, we replace $p_{d}$ with $p_{\boldsymbol{\theta}}$ when writing the posteriors, i.e.:
$$ \begin{align} p(\mathcal{D}=1|\boldsymbol{w}) & = \frac{p_{\boldsymbol{\theta }}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \\ p(\mathcal{D}=0|\boldsymbol{w}) & = \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \end{align} $$
Let's expand the expectation and take a look:
$$ \begin{align} \mathbb{E}_{\boldsymbol{w}\sim \mathcal{D}} \log p(\mathcal{D}|\boldsymbol{w}) & = \int_{\boldsymbol{w}} p(\boldsymbol{w})\log p(\mathcal{D}|\boldsymbol{w}) \, d\boldsymbol{w} \\ &= \int_{\boldsymbol{w}} \frac{1}{k+1}(p_{d}(\boldsymbol{w})+kp_{n}(\boldsymbol{w}))\log p(\mathcal{D}|\boldsymbol{w})\, d \boldsymbol{w} \\ &= \frac{1}{k+1} \left( \int_{\boldsymbol{w}} p_{d}(\boldsymbol{w}) \log p(\mathcal{D}=1|\boldsymbol{w}) \, d \boldsymbol{w} + \int _{\boldsymbol{w}} k p_{n}(\boldsymbol{w}) \log p(\mathcal{D}=0|\boldsymbol{w}) \, d\boldsymbol{w} \right) \\ &= \frac{1}{k+1}\bigg(\mathbb{E}_{\boldsymbol{w} \sim p_{d}} \log p(\mathcal{D}=1|\boldsymbol{w})+k\mathbb{E}_{\boldsymbol{w} \sim p_{n}}\log p(\mathcal{D}=0|\boldsymbol{w})\bigg) \end{align} $$
At first glance, this expectation raises two questions:
First, why expand the probability over $\boldsymbol{w}$ rather than over $\boldsymbol{c}$? Either is possible in principle, but our goal is to make the model distribution fit the training-set distribution; expanding over $\boldsymbol{c}$ would amount to $p_{\boldsymbol{\theta}}(\boldsymbol{c}) \approx p_{d}(\boldsymbol{c})$, i.e. asking the model to predict its own input features, which makes little sense.
Second, weren't we supposed to replace the data distribution with the model distribution? Why is $p(\boldsymbol{w})$ still the data distribution? My guess is that this is for ease of training: if both places used the model distribution, training would inevitably be harder. Besides, the data distribution is a fixed fact and provides extra information to the model, speeding up convergence.
Since the factor $\frac{1}{k+1}$ is a constant and has no effect on the optimization, we drop it in what follows. Our objective function is then:
$$ J(\boldsymbol{\theta }) =\mathbb{E}_{\boldsymbol{w} \sim p_{d}} \log p(\mathcal{D}=1|\boldsymbol{w})+k\mathbb{E}_{\boldsymbol{w} \sim p_{n}}\log p(\mathcal{D}=0|\boldsymbol{w}) $$
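As a minimal Monte Carlo sketch of this objective (all distributions below are toy stand-ins), one data sample and $k$ noise samples per context estimate the two expectations; summing the $k$ noise terms absorbs the factor $k$ in front of the second expectation:

```python
import torch

k, V = 4, 8
p_d = torch.softmax(torch.randn(V), dim=0)       # stand-in for the data distribution
p_n = torch.full((V,), 1.0 / V)                  # uniform noise distribution
p_theta = torch.softmax(torch.randn(V), dim=0)   # stand-in for the model distribution

w_pos = torch.multinomial(p_d, 1)                      # w ~ p_d
w_neg = torch.multinomial(p_n, k, replacement=True)    # k draws of w ~ p_n

post_pos = p_theta[w_pos] / (p_theta[w_pos] + k * p_n[w_pos])    # p(D=1|w)
post_neg = k * p_n[w_neg] / (p_theta[w_neg] + k * p_n[w_neg])    # p(D=0|w)
J = torch.log(post_pos).sum() + torch.log(post_neg).sum()        # one-context estimate of J(theta)
print(J.item())
```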
A View from the Limit
You must be wondering whether this proxy problem actually approximates the original objective. The derivative of the objective with respect to $\boldsymbol{\theta}$ gives us the answer.
$$ J(\boldsymbol{\theta }) =\mathbb{E}_{\boldsymbol{w} \sim p_{d}} \log p(\mathcal{D}=1|\boldsymbol{w})+k\mathbb{E}_{\boldsymbol{w} \sim p_{n}}\log p(\mathcal{D}=0|\boldsymbol{w}) $$
Taking the derivative with respect to the parameters $\boldsymbol{\theta}$:
$$ \begin{align} \frac{ \partial }{ \partial \boldsymbol{\theta } } J(\boldsymbol{\theta }) & = \frac{ \partial }{ \partial \boldsymbol{\theta } } \mathbb{E}_{\boldsymbol{w} \sim p_{d}} \log p(\mathcal{D}=1|\boldsymbol{w})+\frac{ \partial }{ \partial \boldsymbol{\theta } } k\mathbb{E}_{\boldsymbol{w} \sim p_{n}}\log p(\mathcal{D}=0|\boldsymbol{w}) \\ &= \mathbb{E}_{\boldsymbol{w}\sim p_{d}} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log \frac{p_{\boldsymbol{\theta }}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} + k\mathbb{E}_{\boldsymbol{w}\sim p_{n}} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \end{align} $$
Now we work out the two terms one at a time. The first:
$$ \begin{align} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log \frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} & = \frac{p_{\theta }(\boldsymbol{w})+kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \\ &= \frac{p_{\theta }(\boldsymbol{w})+kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})} \frac{p_{\boldsymbol{\theta}}'(\boldsymbol{w})(p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w}))-p_{\boldsymbol{\theta}}(\boldsymbol{w})p_{\boldsymbol{\theta}}'(\boldsymbol{w})}{(p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w}))^{2} } \\ &= \frac{p_{\boldsymbol{\theta }}'(\boldsymbol{w})kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})(p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w}))} \\ &= \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} {\color{blue}\frac{p_{\boldsymbol{\theta }}'(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})}} \\ &=\frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta}}(\boldsymbol{w}) \end{align} $$
And the other term:
$$ \begin{align} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} & =\frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})}{kp_{n}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \\ &= \frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})}{kp_{n}(\boldsymbol{w})} \frac{0-kp_{n}(\boldsymbol{w})p_{\boldsymbol{\theta}}'(\boldsymbol{w})}{(p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w}))^{2}} \\ &= -\frac{p_{\boldsymbol{\theta}}'(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} {\color{blue}\frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})}}\\ &= - \frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta}}(\boldsymbol{w}) \end{align} $$
Putting them together:
$$ \begin{align} \frac{ \partial }{ \partial \boldsymbol{\theta } } J(\boldsymbol{\theta }) & = \mathbb{E}_{\boldsymbol{w} \sim p_{d}} \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta}}(\boldsymbol{w}) -k \mathbb{E}_{\boldsymbol{w} \sim p_{n}}\frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta}}(\boldsymbol{w}) \\ &= \sum_{\boldsymbol{w}} p_{d}(\boldsymbol{w}) \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta }}(\boldsymbol{w}) -k \sum_{\boldsymbol{w}}p_{n}(\boldsymbol{w})\frac{p_{\boldsymbol{\theta}}(\boldsymbol{w})}{p_{\boldsymbol{\theta}}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta}}(\boldsymbol{w}) \\ &= \sum_{\boldsymbol{w}} \underbrace{ {\color{blue}\frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})}} }_{ k \to \infty, ratio \to 1 }\bigg(p_{d}(\boldsymbol{w}) - p_{\boldsymbol{\theta }}(\boldsymbol{w})\bigg) \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta }}(\boldsymbol{w}) \\ &\approx \sum_{\boldsymbol{w}}\bigg(p_{d}(\boldsymbol{w}) - p_{\boldsymbol{\theta }}(\boldsymbol{w})\bigg) \frac{ \partial }{ \partial \boldsymbol{\theta } } \log p_{\boldsymbol{\theta }}(\boldsymbol{w}) \end{align} $$
For the approximation in the last step, consider $10000/(1+10000)$ as an example: the $1$ is negligible compared to $10000$, so it can simply be dropped.
This is exactly why the proxy problem works: when enough noise samples are drawn, the NCE gradient approaches the maximum-likelihood gradient we wanted to follow in the first place.
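Here is a minimal numerical check of this limit, assuming a toy softmax model over a 10-word vocabulary (all quantities are made up): as $k$ grows, the NCE gradient should approach $\sum_{\boldsymbol{w}}(p_{d}(\boldsymbol{w}) - p_{\boldsymbol{\theta}}(\boldsymbol{w}))\frac{\partial}{\partial \boldsymbol{\theta}}\log p_{\boldsymbol{\theta}}(\boldsymbol{w})$.

```python
import torch

V = 10
torch.manual_seed(0)
p_d = torch.softmax(torch.randn(V), dim=0)      # toy data distribution
p_n = torch.full((V,), 1.0 / V)                 # uniform noise distribution
p_t = torch.softmax(torch.randn(V), dim=0)      # model distribution p_theta

# For a softmax model, d log p_theta(w) / d theta = one_hot(w) - p_theta.
g = torch.eye(V) - p_t                          # row w holds the gradient vector for word w

grad_mle = ((p_d - p_t)[:, None] * g).sum(dim=0)               # the target gradient
for k in (1, 10, 100, 10_000):
    weight = k * p_n / (p_t + k * p_n)                         # the blue ratio above
    grad_nce = ((weight * (p_d - p_t))[:, None] * g).sum(dim=0)
    print(k, (grad_nce - grad_mle).norm().item())              # gap shrinks as k grows
```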
Two Approximations
Estimating $Z(\boldsymbol{\theta})$ by introducing an extra parameter is clever, but there is a big problem: although the vocabulary $\mathcal{V}$ is the same for every pair, the contexts $\boldsymbol{c}$ differ, so $p(\boldsymbol{w}|\boldsymbol{c})$ differs as well. In other words, you would have to store a separate parameter $z^{\boldsymbol{c}}$ for every context.
At this point the authors (not of the original NCE paper, but of http://arxiv.org/abs/1206.6426) played a "divine move": simply set $Z(\boldsymbol{\theta}) \approx 1$, the so-called self-normalization. In other words, the scores are never actually converted into probabilities. Reading this, you probably can't help a look of disdain; neither could I.
Why does self-normalization work at all? Quoting the original paper:
We believe this is because the model has so many free parameters that meeting the approximate per-context normalization constraint encouraged by the objective function is easy.
In other words, the model has so many parameters that it has the capacity to learn the normalization on its own. Of course, the original paper compared the two settings and found almost no difference in quality before settling on this choice.
In my view, the credit also goes to the choice of objective function: when the gradient is approximately $0$, $p_{d}$ and $p_{\boldsymbol{\theta}}$ are already close to each other.
There is an easily misunderstood point here: this objective function only fits a single context-word pair $(\boldsymbol{c}, \boldsymbol{w})$:
$$ J(\boldsymbol{\theta }) =\mathbb{E}_{\boldsymbol{w} \sim p_{d}} \log p(\mathcal{D}=1|\boldsymbol{w})+k\mathbb{E}_{\boldsymbol{w} \sim p_{n}}\log p(\mathcal{D}=0|\boldsymbol{w}) $$
Computing this expectation for every pair, i.e. considering every possible candidate, would be far too expensive, so the original paper makes a second approximation (some sources describe drawing $k$ samples as a form of Monte Carlo simulation):
$$ J^{\boldsymbol{c}}(\boldsymbol{\theta }) = \log p(\mathcal{D}=1|\boldsymbol{w}_{0}) + \sum_{i=1}^{k} \log p(\mathcal{D}=0|\boldsymbol{w}_{i}) $$
So how do we model all context-word pairs? We simply define a global NCE objective and optimize that:
$$ J(\boldsymbol{\theta }) = \sum_{\boldsymbol{c}} p(\boldsymbol{c})J^{\boldsymbol{c}}(\boldsymbol{\theta }) $$
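In practice (matching the averaging in the code at the end of this post), the sum over contexts is itself estimated with a minibatch of contexts drawn from $p(\boldsymbol{c})$:

$$ J(\boldsymbol{\theta}) \approx \frac{1}{B}\sum_{b=1}^{B} J^{\boldsymbol{c}_{b}}(\boldsymbol{\theta}), \qquad \boldsymbol{c}_{b} \sim p(\boldsymbol{c}) $$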
A Sigmoid Cameo
Of course, if you look at the implementations in today's machine learning libraries, you may notice they look a little different from the formulas above. Let's massage the expression a bit:
$$ \begin{align} p(\mathcal{D}=1|\boldsymbol{w}) & = \frac{p_{\boldsymbol{\theta }}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})+kp_{n}(\boldsymbol{w})} \\ &= \frac{1}{1+ \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})}} \\ &= \frac{1}{1+ \exp\left(\log \left( \frac{kp_{n}(\boldsymbol{w})}{p_{\boldsymbol{\theta }}(\boldsymbol{w})} \right)\right)} \\ &= \frac{1}{1+\exp(\log kp_{n}(\boldsymbol{w})-\log p_{\boldsymbol{\theta }}(\boldsymbol{w}))} \\ &= \frac{1}{1+\exp({\color{red}-}(\underbrace{ \log p_{\boldsymbol{\theta }}(\boldsymbol{w})-\log kp_{n}(\boldsymbol{w})) }_{ x })} \end{align} $$
Treating the inner expression as a single variable $x$, we recognize the sigmoid function:
$$ p(\mathcal{D}=1|\boldsymbol{w}) = \sigma(\log p_{\boldsymbol{\theta}}(\boldsymbol{w})-\log kp_{n}(\boldsymbol{w})) $$
Similarly:
$$ p(\mathcal{D}=0|\boldsymbol{w}) = \sigma(\log kp_{n}(\boldsymbol{w}) - \log p_{\boldsymbol{\theta}}(\boldsymbol{w})) $$
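A quick numerical check with made-up values that the sigmoid form really is the posterior derived earlier:

```python
import torch

p_theta, p_n, k = torch.tensor(0.3), torch.tensor(0.05), 5
lhs = torch.sigmoid(torch.log(p_theta) - torch.log(k * p_n))   # sigmoid form
rhs = p_theta / (p_theta + k * p_n)                            # posterior p(D=1|w)
print(lhs.item(), rhs.item())                                  # both ≈ 0.5455
```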
The loss function now practically writes itself:
$$ \begin{align} L(\boldsymbol{\theta }) & = - \sum_{\boldsymbol{c}} p(\boldsymbol{c})\left( \log p(\mathcal{D}=1|\boldsymbol{w}_{0})+\sum_{i=1}^{k} \log p(\mathcal{D}=0|\boldsymbol{w}_{i}) \right) \\ &= -\sum_{\boldsymbol{c}} p(\boldsymbol{c}) \left( \log \sigma(\log p_{\boldsymbol{\theta}}(\boldsymbol{w}_{0})-\log kp_{n}(\boldsymbol{w}_{0})) + \sum_{i=1}^{k} \log \sigma(\log kp_{n}(\boldsymbol{w}_{i})- \log p_{\boldsymbol{\theta}}(\boldsymbol{w}_{i}))\right) \\ &= -\sum_{\boldsymbol{c}} p(\boldsymbol{c})\bigg(\log\sigma(s_{\boldsymbol{\theta }}(\boldsymbol{c}, \boldsymbol{w}_{0})-\log kp_{n}(\boldsymbol{w}_{0}))+\sum_{i=1}^{k} \log \sigma(\log kp_{n}(\boldsymbol{w}_{i})- s_{\boldsymbol{\theta }}(\boldsymbol{c}, \boldsymbol{w}_{i}))\bigg) \end{align} $$
Show Me the Code
Here $p_{n}$ is chosen to be log-uniform: the later a class index, the smaller its probability. So to use it, sort the classes by frequency and put the frequent ones first. For example, if classes A, B, C, D occur 10, 20, 100, 15 times respectively, the ordering should be C B D A, and class index 0 corresponds to C. `range_max` is the total number of classes, here 4.
$$ p(\text{class}) = \frac{\log(\text{class} + 2) - \log(\text{class} + 1)}{\log(\text{range\_max} + 1)} $$
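Plugging the four-class example above into this formula (range_max = 4) gives roughly:

$$ p(0) = \frac{\log 2}{\log 5} \approx 0.43,\quad p(1) = \frac{\log 3 - \log 2}{\log 5} \approx 0.25,\quad p(2) = \frac{\log 4 - \log 3}{\log 5} \approx 0.18,\quad p(3) = \frac{\log 5 - \log 4}{\log 5} \approx 0.14 $$

and the telescoping sum is exactly 1.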
Below is the training loss. At eval time there is no noise: just take the logits corresponding to the labels and compute your metrics. For numerical stability, PyTorch's built-in softplus is used in place of logsigmoid; see Numerical Stability.
```python
import math

import torch.nn.functional as F
from einops import repeat
from torch import arange, log, multinomial, randn, tensor


def nce_loss(logits_pos, logits_neg, log_pn_pos, log_pn_neg, k):
    """Compute the noise contrastive estimation loss in
    https://arxiv.org/abs/1806.03664.

    Params:
        - logits_pos: Tensor. Shape: (bs, 1). Logits corresponding to labels.
        - logits_neg: Tensor. Shape: (bs * k, 1). Logits corresponding to sampled classes.
        - log_pn_pos: Tensor. Shape: (bs, 1). Log-probability of labels under the noise distribution.
        - log_pn_neg: Tensor. Shape: (bs * k, 1). Log-probability of noise candidates under the noise distribution.
        - k: int. The number of noise candidates per training example.

    Note:
        This implementation assumes each context is equally likely, which leads to the final average.
    """
    logk = math.log(k)
    # For numerical stability, replace logsigmoid with torch's softplus,
    # which handles the overflow case: log(sigmoid(x)) = -softplus(-x).
    # The loss carries a leading minus sign, so the two minuses cancel.
    pos = F.softplus((logk + log_pn_pos) - logits_pos).mean()
    neg = F.softplus(logits_neg - (logk + log_pn_neg)).mean()
    return pos + neg


def log_uniform(num_sampled, range_max, replacement=True):
    """Sample classes from the log-uniform distribution:
    p(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1),
    with sampled classes in [0, range_max).

    Note that the data should follow the same ordering as the log-uniform prior,
    e.g. classes should be sorted in decreasing order of frequency for text generation.

    Params:
        - num_sampled: int. The number of classes to sample.
        - range_max: int. The total number of classes.
        - replacement: bool. If False, sampled candidates are unique.

    Examples:
        >>> log_uniform(2, 10)
        >>> # (tensor of 10 probabilities, tensor([7, 2]))
    """
    classes = arange(0, range_max)
    probs = log((classes + 2) / (classes + 1)) / math.log(range_max + 1)
    return probs, multinomial(probs, num_sampled, replacement=replacement)


def main():
    bs, k = 2, 4
    num_classes = 8
    logits = randn(bs, num_classes)
    labels = tensor([2, 4])

    probs, noise_classes = log_uniform(bs * k, num_classes)
    # nce_loss expects log-probabilities with shapes (bs, 1) and (bs * k, 1).
    log_pn_pos = probs[labels].log()[:, None]
    log_pn_neg = probs[noise_classes].log()[:, None]

    logits_pos = logits.take_along_dim(labels[:, None], dim=1)
    # Repeat each example's logits k times so that every noise candidate
    # is matched with the logits of its own context.
    logits_k = repeat(logits, 'b h -> (b k) h', k=k)
    logits_neg = logits_k.take_along_dim(noise_classes.reshape(bs * k, -1), dim=1)

    loss = nce_loss(logits_pos, logits_neg, log_pn_pos, log_pn_neg, k)
    print('nce loss: %f' % loss.item())


if __name__ == '__main__':
    main()
```
As for experiments, I'll procrastinate for now and save them for a later post comparing NCE with InfoNCE, negative sampling, and friends.