Detecting and Correcting for Label Shift with Black Box Predictors(BBSE)

Posted by Lucius on October 5, 2022

概述

首先从一个流感的例子讲起,医院在八月根据当月数据训练了模型 $f$,假设其特征 $\boldsymbol{x}$ 为「有无咳嗽」,预测标签 $y$ 为「有无得流感」。

后续几个月模型 $f$ 运转良好,但到第二年二月时,医院发现 $f$ 预测为「得流感」的人数大幅增加,此时我们知道这与「冬季是流感高发期」有关。但一个问题随即出现了,用八月数据训出的 $f$ 是否在二月也能有效预测,其在八月数据上学得的先验是否会影响二月时的判断。

将问题形式化,我们可以发现八月和二月的 $p(y\mid \boldsymbol{x})=p($流感 $\mid$ 咳嗽$)$ 和 $p(y)=p($流感$)$ 明显发生了变化,因此过往在「covariate shift」上的研究不再适用。

继续深入,我们可以发现 $p(\boldsymbol{x}\mid y)=p($咳嗽 $\mid$ 流感$)$ 似乎并没有发生太大的变化,由此引入本篇文章所关注的「label shift」问题,其代表下述这种情况:

  • 标签边际分布 $p(y)$ 发生变化,但条件分布 $p(\boldsymbol{x}\mid y)$ 不变

随后文中提出「Black Box Shift Estimation (BBSE)」方法,利用「黑盒预测器」来估计变化的 $p(y)$,且仅要求其对应「混淆矩阵 (confusion matrices)」是可逆的,即使预测器是 biased,inaccurate 或 uncalibrated。

问题设定

源域:$\mathcal{X}\times \mathcal{Y}$ 上的分布 $P$,$D=\{(\boldsymbol{x}_i, y_i)\}_{i=1}^n$,基于 $D$ 训练得到的黑盒模型 $f:\mathcal{X}\rightarrow \mathcal{Y}$

目标域:$\mathcal{X}\times \mathcal{Y}$ 上的分布 $Q$,$X’=[\boldsymbol{x}_1’;…;\boldsymbol{x}_m’]$

目标:检测 $P\rightarrow Q$ 是否发生了「label shift」,若发生了则重新训练模型,使其适应分布 $Q$

三大假设

  • 「label shift / target shift」假设:
$$ p(\boldsymbol{x} \mid y)=q(\boldsymbol{x} \mid y) \quad \forall x \in \mathcal{X}, y \in \mathcal{Y} $$
  • $\forall y\in \mathcal{Y}$,若 $q(y)>0$ 则 $p(y)>0$

  • $f$ 对应的混淆矩阵 (confusion matrix) $\mathrm{C}_p(f)$ 可逆,矩阵定义如下:

$$ \mathbf{C}_P(f):=p(f(x), y) \in \mathbb{R}^{|\mathcal{Y}| \times|\mathcal{Y}|} $$

BBSE

「Black Box Shift Estimation (BBSE)」方法主要用于估计 $w(y)=q(y)/p(y)$,其核心思路如下:

$$ \begin{aligned} q(\hat{y}) &=\sum_{y \in \mathcal{Y}} q(\hat{y} \mid y) q(y) \\ &=\sum_{y \in \mathcal{Y}} p(\hat{y} \mid y) q(y)=\sum_{y \in \mathcal{Y}} p(\hat{y}, y) \frac{q(y)}{p(y)} \end{aligned} $$

其中 $\hat{y}$ 即 $f$ 给出的伪标记,而 $q(\hat{y}\mid y)=p(\hat{y}\mid y)$ 则来自于下述推导:

$$ \begin{aligned} &q(\hat{y} \mid y)=\sum_{\boldsymbol{x} \in \mathcal{X}} q(\hat{y} \mid \boldsymbol{x}, y) q(\boldsymbol{x} \mid y)=\sum_{\boldsymbol{x} \in \mathcal{X}} q(\hat{y} \mid \boldsymbol{x}, y) p(\boldsymbol{x} \mid y) \\ &=\sum_{\boldsymbol{x} \in \mathcal{X}} p_f(\hat{y} \mid \boldsymbol{x}) p(\boldsymbol{x} \mid y)=\sum_{\boldsymbol{x} \in \mathcal{X}} p(\hat{y} \mid \boldsymbol{x}, y) p(\boldsymbol{x} \mid y)=p(\hat{y} \mid y) \end{aligned} $$

其关键部分在于 $q(\boldsymbol{x}\mid y)=p(\boldsymbol{x}\mid y)$ 的假设以及 $\hat{y} \perp !!! \perp y \mid \boldsymbol{x}$ 的条件独立性。随后便可以得到:

$$ \begin{gathered} \mu_{\hat{y}}=\mathrm{C}_{\hat{y} \mid y} \mu_y=\mathrm{C}_{\hat{y}, y} w \\ \hat{\boldsymbol{w}}=\hat{\mathbf{C}}_{\hat{y}, y}^{-1} \hat{\boldsymbol{\mu}}_{\hat{y}} \\ \hat{\boldsymbol{\mu}}_y=\operatorname{diag}\left(\hat{\boldsymbol{\nu}}_y\right) \hat{\boldsymbol{w}} \end{gathered} $$

其中各符号定义如下,其核心思想就是本节最开头的公式,只不过为了严谨而引入了大量符号,但实质不变。

理论保障

首先是「Consistency」的保证:

其次是「Error bounds」方面的保证:

根据上述「Error bounds」的结果,可以发现在选择黑盒模型时,「$\mathrm{C}_{\hat{y}, y}$ 最小奇异值」越大的模型越合适。

Label-Shift 检测

在最开头的三大假设下,$q(y)=p(y)\Leftrightarrow p(\hat{y})=q(\hat{y})$,因此使用「two-sample tests」对 $p(\hat{y})=q(\hat{y})$ 进行检测即可。

让模型适应新分布

计算出 $\hat{\boldsymbol{w}}$ 后,采用「importance weighted ERM」在源域数据集 $\mathcal{D}$ 上重新训练模型即可,具体训练目标如下:

$$ \mathcal{L}=\sum_{i=1}^n \hat{w}_i\cdot \ell\left(y_i, \boldsymbol{x}_i\right) $$

整体算法如下:

检测 Label-Shift 假设成立

采用「kernel two-sample tests」检测下述式子是否成立:

$$ \mathbb{E}_p[\boldsymbol{w}(y) k(\phi(\boldsymbol{x}), \cdot)]=\mathbb{E}_q[k(\phi(\boldsymbol{x}), \cdot)] $$

即转化为下述 MMD 距离的计算:

$$ \left\|\frac{1}{n} \sum_{i=1}^n\left[\hat{\boldsymbol{w}}\left(y_i\right) k\left(\phi\left(\boldsymbol{x}_i\right), \cdot\right)\right]-\frac{1}{m} \sum_{j=1}^m k\left(\phi\left(\boldsymbol{x}_j^{\prime}\right), \cdot\right)\right\|_{\mathcal{H}}^2 $$

参考资料