ICML23 - Explore and Exploit the Diverse Knowledge in Model Zoo for Domain Generalization

Posted by Lucius on September 5, 2023

1. 概述

文章出发点:

  • 近年来预训练模型库兴起,如何有效利用模型库中的信息,获取对下游任务有效的模型,成为重要的研究方向;

  • 先前的研究主要关注于如何识别模型库中最有效的模型,因此未充分利用模型库中多样的归纳偏好;

  • 本文认为 weaker models 中的知识也非常有价值,并提出通过刻画「feature diversity shift」和「feature correlation shift」来利用这些多样性的方法。

文章关注点:

  • 在领域泛化 (Domain generalization, DG) 中, 模型在多个 source domains 上训练得到,其目的是提升它们在 unseen domains 上的泛化能力(与 Domain adaptation 不同,在 DG 问题中,无法获得测试数据);

  • 本文期望利用 model zoo 的能力,包括其中的 weaker models,使得 DG 的性能得到提升。

领域泛化的基本设定:

  • 训练集包括多个不同 domain 上的不同分布的数据集(输入、输出空间一致的数据集);

  • 模型在训练集上训练,希望其在未见数据(测试数据,与训练数据不同 domain,不同分布)上表现优异;

  • 下图为 PACS 数据集示例:

本文方法的整体思路:

  • 使用 SOTA 的查搜方法,根据测试集在预训练模型库中选出 Top-1 模型 (main model);

  • 再根据本文给出的度量选择一些较差的模型 (auxiliary model),希望它们能指导 Top-1 模型在训练集上的微调方式(相当于提供错误示例);

  • 希望最终微调后的 Top-1 模型能在测试集上表现优异。

2. Model Exploration

2.1 Feature Diversity and Correlation Shifts

包含多个 domain ($\mathcal{E}$) 的数据集为 $\mathcal{D}=\left\{D_e\right\}_{e \in \mathcal{E}}$,其中 $D_e=\left\{x_i^e, y_i^e\right\}_{i=1}^{n^e}$ $i.i.d.$ 采样于 $\mathbb{P}^e(\mathcal{X} \times \mathcal{Y})$。

将 PTM (Pre-Training Model) 看作特征提取器($\phi: \mathcal{X} \rightarrow \mathcal{Z}_\phi$),我们关注的是 PTMs 在不同 domain $(\forall e\in \mathcal{E})$ 上 $\mathbb{P}^e(\phi(X),Y)$ 的差异。

由于 $\mathbb{P}^e(\phi(X),Y)=\mathbb{P}^e(Y \mid \phi(X)) \mathbb{P}^e(\phi(X))$,在本文中,$\mathbb{P}^e(Y \mid \phi(X))$ 被称为 feature correlation shift,$\mathbb{P}^e(\phi(X))$ 被称为 feature diversity shift.

本文使用下述两个 metrics 来度量上述两种 shift ($\phi: \mathrm{x} \mapsto \mathrm{z}$):

$$ \begin{aligned} & F_{div }\left(\phi, e, e^{\prime}\right)=\frac{1}{2} \int_{\mathcal{S}}\left|p_e(\mathbf{z})-p_{e^{\prime}}(\mathbf{z})\right| \mathrm{d} \mathbf{z}, \\ & F_{cor }\left(\phi, e, e^{\prime}\right)=\frac{1}{2} \int_{\mathcal{T}} \tilde{p}_{e, e^{\prime}}(\mathbf{z}) \sum_{y \in \mathcal{Y}}\left|p_e(y \mid \mathbf{z})-p_{e^{\prime}}(y \mid \mathbf{z})\right| \mathrm{d} \mathbf{z}, \end{aligned} $$

其中 $\tilde{p}_{e, e^{\prime}}$ 是 $p_e,p_{e'}$ 的几何平均数,且 $\mathcal{S},\mathcal{T}$ 定义如下:

$$ \begin{aligned} & \mathcal{S}\left(\phi, e, e^{\prime}\right):=\left\{\mathbf{z} \in \mathcal{Z}_\phi \mid p_e(\mathbf{z}) \cdot p_{e^{\prime}}(\mathbf{z})=0\right\}, \\ & \mathcal{T}\left(\phi, e, e^{\prime}\right):=\left\{\mathbf{z} \in \mathcal{Z}_\phi \mid p_e(\mathbf{z}) \cdot p_{e^{\prime}}(\mathbf{z}) \neq 0\right\}. \end{aligned} $$

因此 $F_{div}$ 刻画了特征 $(\phi(\mathbf{x}))$ 在两个 domain 中不会共同出现的部分,$F_{cor}$ 刻画了特征 $(\phi(\mathbf{x}))$ 与目标 $(y)$ 之间的相关性 $(p(y\mid \mathbf{z}))$ 在不同 domain 中的差异。

2.2 Practical Estimation

经验估计 $F_{div }$

$F_{div }$ 实际上就是 domain $e$ 和 $e’$ 在对方支撑集之外的概率密度积分之和,可以简化为下式:

$$ F_{\text {div }}\left(\phi, e, e^{\prime}\right)=\frac{1}{2}\left(\mathbb{P}^e\left[\mathcal{S}_e\left(e^{\prime}, \phi\right)\right]+\mathbb{P}^{e^{\prime}}\left[\mathcal{S}_{e^{\prime}}(e, \phi)\right]\right), $$

其中 $\mathcal{S}_e\left(e^{\prime}, \phi\right)=\left\{\mathbf{z} \in \mathcal{Z}_\phi \mid p_e(\mathbf{z})>0,p_{e'}(\mathbf{z})=0\right\}$。为了对其做经验估计,可以遍历 domain 对应数据集中的样本,并认为样本概率小于 $\epsilon_{e}$ 时即为 0,即:

$$ \hat{\mathbb{P}}^e\left[\hat{\mathcal{S}}_e\left(e^{\prime}, \phi\right)\right]=\hat{\mathbb{P}}^e\left(\left\{\mathbf{x} \in D_e \mid \hat{p}_{e^{\prime}}(\mathbf{z})<\epsilon_{e^{\prime}}, \mathbf{z}=\phi(\mathbf{x})\right\}\right). $$

其中 $\epsilon_{e’}$ 根据下式确定:

$$ \hat{\mathbb{P}}^{e^{\prime}}\left(\left\{\mathrm{x} \in D_{e^{\prime}} \mid \hat{p}_{e^{\prime}}(\mathbf{z})<\epsilon_{e^{\prime}}, \mathbf{z}=\phi(\mathrm{x})\right\}\right)=0.01. $$

另外 $p_e$ 假定为高斯分布 $\mathcal{N}\left(\mu_e, \Sigma_e\right)$,并根据数据集样本估计分布参数。

经验估计 $F_{cor}$

首先定义 $\mathcal{T}$ 的经验集合:

$$ \hat{\mathcal{T}}\left(\phi, e, e^{\prime}\right)=\left(D_e \backslash \hat{S}_e\left(e^{\prime}, \phi\right)\right) \cup\left(D_{e^{\prime}} \backslash\left.\hat{S}_{e^{\prime}}(e, \phi)\right)\right.. $$

随后可以得到 $F_{cor}$ 的经验版本:

$$ \hat{F}_{c o r}\left(\phi, e, e^{\prime}\right)=\frac{1}{2} \sum_{\mathbf{x} \in \mathcal{\mathcal { T }}} \hat{p}_{e, e^{\prime}}(\mathbf{x}) \sum_{y \in \mathcal{Y}}\left|\hat{p}_e(y \mid \phi(\mathbf{x}))-\hat{p}_{e^{\prime}}(y \mid \phi(\mathbf{x}))\right|. $$

因此接下来就是对 $\hat{p}_{e^{\prime}}(y \mid \phi(\mathbf{x}))$ 进行估计,此处使用了「极大化证据」技术(与 LogME 中一致),此处不再赘述。

唯一的一点不同在于,本文使用「极大化证据」技术对 $\tilde{p}_{e^{\prime}}(y \mid \phi(\mathbf{x}))$ 进行估计,并对其做了简单的校正(详情见论文附录,个人认为并不重要),得到最终的 $\hat{p}_{e^{\prime}}(y \mid \phi(\mathbf{x}))$。

具体计算

论文中具体计算时,每个模型的两种 shift 为所有 domain pair 对上的均值。

3. Model Zoo Exploitation

该部分内容主要是讨论:如何利用上述两种 shift 来集成模型。

3.1 Diversity Ensemble Method

Hilbert-Schmidt independence criterion (HSIC)

此处引入了一个新技术:Hilbert-Schmidt independence criterion (HSIC),主要用于衡量两个变量之间的独立性,其有如下结果:

$$ \text{HSIC}(X, Y)=0 \Leftrightarrow p(x, y) \equiv p(x) p(y). $$

由于当 $p(x, y) \equiv p(x) p(y)$ 时,对于任意函数 $f,g$,均有:

$$ C[f,g]=\mathbb{E}_{(x, y) \sim p(x, y)}[f(x) g(y)]-\mathbb{E}_{x \sim p(x)}[f(x)] \mathbb{E}_{y \sim p(y)}[g(y)]=0. $$

因此计算 $\text{HSIC}(X, Y)$ 的思路为:遍历所有可能的 $f,g$,计算 $L_H=\sum_{f,g}(C[f,g])^2$,其中:

$$ \begin{aligned} (C[f, g])^2= & \mathbb{E}_{\left(x_1, y_1\right) \sim p(x, y),\left(x_2, y_2\right) \sim p(x, y)}\left[f\left(x_1\right) f\left(x_2\right) g\left(y_1\right) g\left(y_2\right)\right] \\ & +\mathbb{E}_{x_1 \sim p(x), x_2 \sim p(x), y_1 \sim p(y), y_2 \sim p(y)}\left[f\left(x_1\right) f\left(x_2\right) g\left(y_1\right) g\left(y_2\right)\right] \\ & -2 \mathbb{E}_{\left(x_1, y_1\right) \sim p(x, y), x_2 \sim p(x), y_2 \sim p(y)}\left[f\left(x_1\right) f\left(x_2\right) g\left(y_1\right) g\left(y_2\right)\right]. \end{aligned} $$

由于无法遍历所有的 $f$、$g$,因此此处分别定义了 X、Y 变量上的 kernel,通过遍历 kernel 正交基的方式(Mercer’s theorem)进行求解。

转换后,对应的 HSIC 为:

$$ \begin{aligned} \text{HSIC}(X, Y)= & \mathbb{E}_{\left(x_1, y_1\right) \sim p(x, y),\left(x_2, y_2\right) \sim p(x, y)}\left[K_X\left(x_1, x_2\right) K_Y\left(y_1, y_2\right)\right] \\ & +\mathbb{E}_{x_1 \sim p(x), x_2 \sim p(x), y_1 \sim p(y), y_2 \sim p(y)}\left[K_X\left(x_1, x_2\right) K_Y\left(y_1, y_2\right)\right] \\ & -2 \mathbb{E}_{\left(x_1, y_1\right) \sim p(x, y), x_2 \sim p(x), y_2 \sim p(y)}\left[K_X\left(x_1, x_2\right) K_Y\left(y_1, y_2\right)\right]. \end{aligned} $$

其经验版本(有偏估计)为:

$$ \text{HSIC}(X, Y)=\frac{1}{n^2}\text{Tr}(K_XJK_YJ), $$

其中 $K_X$、$K_Y$ 分别为对应的核矩阵,$J=I-1/n$,$I$ 为 $n$ 阶单位阵。相应的无偏估计版本为:

$$ \text{HSIC}(X, Y)=\frac{1}{(n-1)^2}\text{Tr}(K_XJK_YJ). $$
本文方法

首先根据 ZooD(该文章团队先前一种挑选模型的方法)选出 Top-1 模型,称为 main model $f_M$.

随后,再根据前文两种 shift 的度量,选出 diversity shift 程度最高的模型,称为 auxiliary model $f_d$.

随后采用下式作为 $f_M$ 的训练目标(微调):

$$ \begin{aligned} \mathcal{L}\left(f_M\right):=\min _{f_M} \mathbb{E}_{X, Y \sim \mathbb{P}_{\mathcal{D}}} \left[\mathcal{L}_c\left(Y, f_M(X)\right) +\lambda \operatorname{HSIC}_d\left(f_M, f_d\right)\right] . \end{aligned} $$

其中 $\operatorname{HSIC}_d\left(f_M, f_d\right)$ 刻画的是「$f_M$ 加上分类头的输出」与「$f_d$ 的输出」之间的相关性,希望相关性越小越好。

文中认为此举可以尽量减少 diversity shift 带来的失败。

3.2 Correlation Ensemble Method

与上文不同,此处通过 correlation shift 度量,选择 auxiliary model $f_c$,而 main model $f_M$ 的选择方式不变。

该部分采用的是一种去除数据集偏差的方法 (Ensemble-based debiasing (EBD) methods),整体思想是用一个泛化较差的模型拟合训练集,此时得到一个 bias-only model。随后再利用 bias-only model 指导 main model 的训练,使得 main model 不会在这些 dataset bias 处过拟合,导致泛化性能下降。

具体来说,本文首先将 $f_c$ 在训练集上微调,随后利用 $f_c$ 输出的 logits 对训练集数据加权,依次来指导 main model 的训练,训练目标如下所示:

$$ \mathcal{L}_{\mathcal{B}}\left(f_M\right):=\frac{1}{|\mathcal{B}|} \sum_{(\mathbf{x}, y) \in \mathcal{D}} m\left(\frac{p(y)}{p_c(\mathbf{x})_y \cdot T}\right) \mathcal{L}_c\left(y, f_M(\mathbf{x})\right), $$

其中 $\mathcal{B}$ 代表一个 batch 的大小,$p_c(\mathbf{x})_y$ 表示 $f_c$ 在样本 $\mathbf{x}$ 上输出的关于类别 $y$ 的概率。整体思想是希望 main model 专注那些 $f_c$ 预测不好的数据(更可能带来泛化性能的数据),降低 $f_c$ 表现好的样本(此类样本被视为是数据集本身的一些 bias)的权重。

参考资料