这篇文章把prediction-powered inference (PPI) 和生物统计、计量经济学里的 surrogate outcome model 放在同一个数学框架里。在 AI 时代,黑箱模型给出的预测值 Y ^ \hat Y Y ^ 可以被看作一种廉价、总是可获得的替代结果 (surrogate outcome),但直接把 Y ^ \hat Y Y ^ 当成 Y Y Y 来用通常不是最优的。真正应该校准的对象不是预测值本身,而是目标损失函数的梯度,也就是 imputed loss 的一阶信息。
论文信息
题目:Predictions as Surrogates: Revisiting Surrogate Outcomes in the Age of AI
作者:Wenlong Ji, Lihua Lei, Tijana Zrnic
机构:Stanford University
年份:2025
arXiv:2501.09731
DOI:10.48550/arXiv.2501.09731
关键词:prediction-powered inference, surrogate outcomes, imputed loss, recalibration, control variates
1. Surrogate Outcome Model
我们存在 covariates X X X 和 outcome variable Y Y Y 。参数 θ ⋆ \theta^\star θ ⋆ 定义为以下 Z Z Z -estimator:
E [ U θ ( X , Y ) ] = 0 (1) \mathbb{E}[U_\theta(X,Y)] = 0 \tag{1} E [ U θ ( X , Y )] = 0 ( 1 )
例如,我们可以把 U θ ( X , Y ) U_\theta(X,Y) U θ ( X , Y ) 选择为 score function,即对数似然函数对参数的梯度。
然而,现实中总是存在某些情况,导致收集数据非常耗时甚至对于测量真实的 outcome 是不可能的,这会导致我们的数据 ( X , Y ) (X,Y) ( X , Y ) 非常少。为了增加我们的 sample size,通常的做法是收集一系列的 surrogate outcome Y ^ \hat Y Y ^ ,其中 Y ^ \hat Y Y ^ 与我们感兴趣的 Y Y Y 是相关的,而收集 Y ^ \hat Y Y ^ 的代价很低。因此我们的数据集是 { ( Y i , Y ^ i , X i , D i ) } i = 1 n + N \{(Y_i, \hat Y_i, X_i, D_i)\}_{i=1}^{n+N} {( Y i , Y ^ i , X i , D i ) } i = 1 n + N ,我们把 D i ∈ { 0 , 1 } D_{i} \in \{0,1\} D i ∈ { 0 , 1 } 记为 outcome 是否缺失的 indicators。
除非 Y ^ \hat Y Y ^ 与 Y Y Y 是完美相关的,否则我们用 Y ^ \hat Y Y ^ 来替代 Y Y Y 就不满足 Z Z Z -estimator 的估计方程。Robins et al. [1994] 考虑过一个类似的半参数模型,在假设完全随机缺失 (MCAR) 的条件下,即:
P ( D = 1 ∣ Y , Y ^ , X ) = p for some p ∈ ( 0 , 1 ) 。 \mathbb{P}(D=1 \mid Y, \hat Y, X) = p \quad \text{for some } p \in (0, 1)。 P ( D = 1 ∣ Y , Y ^ , X ) = p for some p ∈ ( 0 , 1 ) 。
他们提出的估计量 θ ^ \hat \theta θ ^ 定义为如下修改后的估计方程的解:
∑ i = 1 n + N ( D i p U θ ( X i , Y i ) − D i − p p ψ θ ( X i , Y ^ i ) ) = 0 , (2) \sum_{i=1}^{n+N} \left( \frac{D_i}{p} U_\theta(X_i, Y_i) - \frac{D_i - p}{p} \psi_\theta(X_i, \hat{Y}_i) \right) = 0, \tag{2} i = 1 ∑ n + N ( p D i U θ ( X i , Y i ) − p D i − p ψ θ ( X i , Y ^ i ) ) = 0 , ( 2 )
其中 ψ θ \psi_\theta ψ θ 是用户指定的函数。Robins 等人证明出其最优选择需要满足:
ψ θ ∗ ( X , Y ^ ) = E [ U θ ( X , Y ) ∣ X , Y ^ ] 。 \psi_\theta^*(X, \hat{Y}) = \mathbb{E}[U_\theta(X, Y) \mid X, \hat{Y}]。 ψ θ ∗ ( X , Y ^ ) = E [ U θ ( X , Y ) ∣ X , Y ^ ] 。
由此得到的估计量 θ ^ \hat \theta θ ^ 是半参数有效 的。
第一项系数 D i / p {D_i}/{p} D i / p :这是逆概率加权 (Inverse Probability Weighting, IPW) 的体现。在MCAR的假设下,每个样本被观测到的概率均为 p p p 。因此,当我们仅使用观测样本 (D i = 1 D_i=1 D i = 1 ) 时,样本的代表性会因缺失而降低。通过将观测样本的贡献放大 1 / p 1/p 1/ p 倍,我们可以“重建”出与完整数据集期望相等的估计方程,从而保证对 θ ⋆ \theta^\star θ ⋆ 的无偏估计。
第二项系数 ( D i − p ) / p {(D_i - p)}/{p} ( D i − p ) / p :这是为了降低估计量方差 而引入的控制变量项。注意到 E [ D i ] = p \mathbb{E}[D_i] = p E [ D i ] = p ,因此该项的期望为零,即 E [ ( D i − p ) / p ] = 0 \mathbb{E}[{(D_i - p)}/{p}] = 0 E [ ( D i − p ) / p ] = 0 ,这意味着它的引入不会改变估计方程的无偏性。然而,通过与第一项中与 ψ θ \psi_\theta ψ θ 相关的部分做减法,可以抵消掉部分随机噪声,从而获得比仅用逆概率加权更小的渐近方差。
尽管最优选择 ψ θ ∗ \psi_\theta^* ψ θ ∗ 能带来渐近有效性,但在实际应用中,我们往往很难准确指定参数化的 ψ θ \psi_\theta ψ θ 。如果选择的 ψ θ \psi_\theta ψ θ 存在严重设定错误,直接用 (2) 式构造的估计量可能比仅使用完整数据(即 D i = 1 D_i=1 D i = 1 的样本)的 X Y XY X Y -only 估计量还要差。后续存在一系列衍生的工作,这里不一一做介绍了。
2. Prediction-Powered Inference (PPI)
Angelopoulos et al. [2023a] 提出了一种将黑盒机器学习模型的预测融入统计推断的方法,称为 Prediction-Powered Inference 。
在 PPI 的设置中,研究者拥有两个数据集:一个有标签,另一个无标签。为了符号上的方便,我们将协变量分为两类:用 X X X 表示定义推断问题的协变量(如公式 (1) 所示),用 W W W 表示可用于预测的、可能是非结构化的高维额外协变量(例如文本或图像)。因此,研究者拥有 n n n 个i.i.d.的有标签数据点 { ( Y i , X i , W i ) } i = 1 n \{(Y_i, X_i, W_i)\}_{i=1}^n {( Y i , X i , W i ) } i = 1 n ,以及 N N N 个 i.i.d. 的无标签数据点 { ( X i , W i ) } i = n + 1 n + N \{(X_i, W_i)\}_{i=n+1}^{n+N} {( X i , W i ) } i = n + 1 n + N 。PPI 的工作假设是 ( X i , W i ) (X_i, W_i) ( X i , W i ) 在两个数据集中的分布是相同的。
目标参数定义为:
θ ⋆ = argmin θ ∈ Θ E [ ℓ θ ( X , Y ) ] , (3) \theta^\star = \operatorname*{argmin}_{\theta \in \Theta} \mathbb{E}[\ell_\theta(X, Y)], \tag{3} θ ⋆ = θ ∈ Θ argmin E [ ℓ θ ( X , Y )] , ( 3 )
这里的估计我们是按照M M M -estimator来进行的,其中我们要求 ℓ θ \ell_\theta ℓ θ 是定义在 θ ∈ R d \theta \in \mathbb{R}^d θ ∈ R d 上的凸损失函数。如果我们取 U θ = ∇ ℓ θ U_\theta = \nabla \ell_\theta U θ = ∇ ℓ θ ,则此目标等价于估计方程目标(公式 (1))。
研究者还可以访问一个黑盒机器学习模型 f f f ,该模型输出预测值 Y ^ = f ( X , W ) \hat{Y} = f(X, W) Y ^ = f ( X , W ) 。该模型不需要同时将 X X X 和 W W W 作为输入;通常情况下,它可能只使用 W W W ,例如在文本标注或图像分类的情境中。借助机器学习模型,研究者可以用预测值 Y ^ \hat{Y} Y ^ 来扩充有标签和无标签数据集。很明显,这恢复了与 surrogate outcome model 相同的问题结构——研究者拥有一个不完整的数据集 { ( Y i , Y ^ i , X i , W i , D i ) } i = 1 n + N \{(Y_i, \hat{Y}_i, X_i, W_i, D_i)\}_{i=1}^{n+N} {( Y i , Y ^ i , X i , W i , D i ) } i = 1 n + N ,其中 Y i Y_i Y i 被观测当且仅当 D i = 1 D_i=1 D i = 1 。
为了方便理解,在这里以现实中的肺癌数据举例子
X: 病人年龄和每天抽抽烟的数量(我们关心的因变量)
Y: 肺部肿瘤的真实大小(结果量少,可以通过开胸手术精确确定)
W: CT片子(高维,允许和X强相关)
Y ^ \hat Y Y ^ : 通过某个模型对W预测得到的肿瘤大小
感觉最近 PPI 文献中的各类估计量都可以写成统一的形式。我们将以下统一公式称为PPI 估计量:
θ ^ g PPI = argmin θ 1 n ∑ i = 1 n ℓ θ ( X i , Y i ) − ( 1 n ∑ i = 1 n g θ ( X i , Y ^ i ) − 1 N ∑ i = n + 1 n + N g θ ( X i , Y ^ i ) ) , (4) \hat{\theta}_g^{\text{PPI}} = \operatorname*{argmin}_\theta \frac{1}{n} \sum_{i=1}^n \ell_\theta(X_i, Y_i) - \left( \frac{1}{n} \sum_{i=1}^n g_\theta(X_i, \hat{Y}_i) - \frac{1}{N} \sum_{i=n+1}^{n+N} g_\theta(X_i, \hat{Y}_i) \right), \tag{4} θ ^ g PPI = θ argmin n 1 i = 1 ∑ n ℓ θ ( X i , Y i ) − ( n 1 i = 1 ∑ n g θ ( X i , Y ^ i ) − N 1 i = n + 1 ∑ n + N g θ ( X i , Y ^ i ) ) , ( 4 )
这里,g θ g_\theta g θ 是一个与方法相关的函数,我们称之为 imputed loss 。如果 ℓ θ \ell_\theta ℓ θ 是凸的,则公式 (4) 本质上等价于公式 (2),其中 ψ θ ( X , Y ^ ) = ∇ g θ ( X , Y ^ ) \psi_\theta(X, \hat{Y}) = \nabla g_\theta(X, \hat{Y}) ψ θ ( X , Y ^ ) = ∇ g θ ( X , Y ^ ) 。
PPI 文献中的现有估计量主要区别在于 g θ g_\theta g θ 的不同选择:
X Y XY X Y -only 估计量 (θ ^ X Y -only \hat{\theta}^{XY\text{-only}} θ ^ X Y -only ):是公式 (4) 的一个特例,其中 g θ ( X , Y ^ ) = 0 g_\theta(X, \hat{Y}) = 0 g θ ( X , Y ^ ) = 0 。该估计量忽略了预测值,仅使用观测到 Y Y Y 的子集数据。
标准 PPI 估计量 (θ ^ PPI \hat{\theta}^{\text{PPI}} θ ^ PPI ):选择 g θ = ℓ θ g_\theta = \ell_\theta g θ = ℓ θ 。
PPI+或相关 power tuning 方法 ∇ g θ = M ^ ∇ ℓ θ ( ⋅ , Y ^ ) \nabla g_\theta=\hat M\nabla \ell_\theta(\cdot,\hat Y) ∇ g θ = M ^ ∇ ℓ θ ( ⋅ , Y ^ )
然而作者指出,这些选择通常不是最优的。根本原因是:最优控制变量不一定是 ∇ ℓ θ ( X , Y ^ ) \nabla \ell_\theta(X,\hat Y) ∇ ℓ θ ( X , Y ^ ) 的线性变换,而应该逼近真实 score 在 ( X , Y ^ ) (X,\hat Y) ( X , Y ^ ) 下的条件期望。
Example: 标准PPI在均值估计中的特例
当 g θ = ℓ θ g_\theta = \ell_\theta g θ = ℓ θ 时,PPI 估计量变为:
θ ^ g PPI = argmin θ 1 N ∑ i = n + 1 n + N ℓ θ ( X i , Y ^ i ) − 1 n ∑ i = 1 n ( ℓ θ ( X i , Y ^ i ) − ℓ θ ( X i , Y i ) ) \hat{\theta}_g^{\text{PPI}} = \operatorname*{argmin}_\theta \frac{1}{N} \sum_{i=n+1}^{n+N} \ell_\theta(X_i, \hat Y_i) - \frac{1}{n} \sum_{i=1}^n \left(\ell_\theta(X_i, \hat Y_i) - \ell_\theta(X_i, Y_i)\right) θ ^ g PPI = θ argmin N 1 i = n + 1 ∑ n + N ℓ θ ( X i , Y ^ i ) − n 1 i = 1 ∑ n ( ℓ θ ( X i , Y ^ i ) − ℓ θ ( X i , Y i ) )
假设我需要估计 Y Y Y 的均值,损失考虑均方损失 ℓ θ ( Y ) = 1 2 ( Y − θ ) 2 \ell_\theta(Y) = \frac{1}{2}(Y-\theta)^2 ℓ θ ( Y ) = 2 1 ( Y − θ ) 2 ,有:
J ( θ ) = 1 N ∑ i = n + 1 n + N 1 2 ( Y ^ i − θ ) 2 − 1 n ∑ i = 1 n ( 1 2 ( Y ^ i − θ ) 2 − 1 2 ( Y i − θ ) 2 ) J(\theta) = \frac{1}{N}\sum_{i=n+1}^{n+N} \frac{1}{2}(\hat Y_i-\theta)^2 - \frac{1}{n} \sum_{i=1}^{n} \left(\frac{1}{2}(\hat Y_i-\theta)^2 - \frac{1}{2}(Y_i-\theta)^2\right) J ( θ ) = N 1 i = n + 1 ∑ n + N 2 1 ( Y ^ i − θ ) 2 − n 1 i = 1 ∑ n ( 2 1 ( Y ^ i − θ ) 2 − 2 1 ( Y i − θ ) 2 )
由于 J ( θ ) J(\theta) J ( θ ) 对 θ \theta θ 可导,有:
J ′ ( θ ) = − 1 N ∑ i = n + 1 n + N ( Y ^ i − θ ) − 1 n ∑ i = 1 n ( Y i − Y ^ i ) = 0 J^\prime(\theta) = -\frac{1}{N} \sum_{i=n+1}^{n+N} (\hat Y_i-\theta) - \frac{1}{n} \sum_{i=1}^n (Y_i-\hat Y_i) = 0 J ′ ( θ ) = − N 1 i = n + 1 ∑ n + N ( Y ^ i − θ ) − n 1 i = 1 ∑ n ( Y i − Y ^ i ) = 0
因此有 θ ^ PPI = 1 N ∑ i = n + 1 n + N Y ^ i − 1 n ∑ i = 1 n ( Y ^ i − Y i ) \hat \theta^{\text{PPI}} = \frac{1}{N} \sum_{i=n+1}^{n+N} \hat Y_i - \frac{1}{n}\sum_{i=1}^n(\hat Y_i-Y_i) θ ^ PPI = N 1 ∑ i = n + 1 n + N Y ^ i − n 1 ∑ i = 1 n ( Y ^ i − Y i ) 。结果比较符合直觉,第一项利用了大量无标签数据的预测均值来降低方差,第二项用有标签数据的预测偏差对均值做修正。
虽然上述选择(如标准 PPI 或 PPI++)通常保证比 X Y XY X Y -only 更高效,但作者指出,除了极少数特殊情况外,现有的 PPI 估计量都未能达到公式 (4) 定义下的最低渐近方差。
3. 最优 imputed loss
回到 Robins et al. [1994] 的工作,早期关于替代结果模型的经典理论已经指明了最优选择的方向。在 PPI 的设定下,最优的 imputed loss 的梯度应当满足:
∇ g θ ⋆ ∗ ( X , Y ^ ) = N n + N E [ ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ ] . (5) \nabla g_{\theta^\star}^*(X, \hat Y) = \frac{N}{n+N} \mathbb{E}[\nabla \ell_{\theta^\star}(X, Y) \mid X, \hat Y]. \tag{5} ∇ g θ ⋆ ∗ ( X , Y ^ ) = n + N N E [ ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ ] . ( 5 )
定理 1 (最优性条件) :
令 θ ⋆ \theta^\star θ ⋆ 是公式 (3) 的唯一解,并假设 n / N → r n/N \to r n / N → r 。在正则条件下,n ( θ ^ g PPI − θ ⋆ ) → d N ( 0 , Σ g PPI ) \sqrt{n}(\hat{\theta}_g^{\text{PPI}} - \theta^\star) \xrightarrow{d} \mathcal{N}(0, \Sigma_g^{\text{PPI}}) n ( θ ^ g PPI − θ ⋆ ) d N ( 0 , Σ g PPI ) 。如果 g θ g_\theta g θ 满足:
∇ g θ ⋆ ( X , Y ^ ) = 1 1 + r s ⋆ ( X , Y ^ ) , 其中 s ⋆ ( X , Y ^ ) = E [ ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ ] , (7) \nabla g_{\theta^\star}(X, \hat Y) = \frac{1}{1+r} s^\star(X, \hat Y), \quad \text{其中 } s^\star(X, \hat Y) = \mathbb{E}[\nabla \ell_{\theta^\star}(X, Y) \mid X, \hat Y], \tag{7} ∇ g θ ⋆ ( X , Y ^ ) = 1 + r 1 s ⋆ ( X , Y ^ ) , 其中 s ⋆ ( X , Y ^ ) = E [ ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ ] , ( 7 )
则该估计量的渐近方差将达到最小。
结论 :最优的 imputed loss 不是唯一的,我们只需要它的梯度在 θ ⋆ \theta^\star θ ⋆ 处满足上述形式。这为我们选择更方便的 g θ g_\theta g θ 提供了自由度。
需要注意的是,直接使用公式 (5) 来计算 g θ g_\theta g θ 可能非常困难。为了在计算上保持优化目标 (4) 的凸性(只要 ℓ θ \ell_\theta ℓ θ 是凸的),作者选择了一个更方便的线性形式:
g θ ( X , Y ^ ) = 1 1 + r θ ⊤ s ⋆ ( X , Y ^ ) . (8) g_\theta(X, \hat Y) = \frac{1}{1+r} \theta^\top s^\star(X, \hat Y). \tag{8} g θ ( X , Y ^ ) = 1 + r 1 θ ⊤ s ⋆ ( X , Y ^ ) . ( 8 )
在这个选择下,目标函数 (4) 只是对标准的经验损失加上了一个线性项,因此凸性得以保持。这极大地简化了优化过程。
例子 1:广义线性模型
假设 ∇ ℓ θ ( X , Y ) = X ( μ ( X ⊤ θ ) − Y ) \nabla \ell_\theta(X, Y) = X(\mu(X^\top \theta) - Y) ∇ ℓ θ ( X , Y ) = X ( μ ( X ⊤ θ ) − Y ) ,则最优的 s ⋆ s^\star s ⋆ 变为:
s ⋆ ( X , Y ^ ) = X ( μ ( X ⊤ θ ⋆ ) − E [ Y ∣ X , Y ^ ] ) = ∇ ℓ θ ⋆ ( X , E [ Y ∣ X , Y ^ ] ) . s^\star(X, \hat Y) = X \left(\mu(X^\top \theta^\star) - \mathbb{E}[Y \mid X, \hat Y]\right) = \nabla \ell_{\theta^\star}(X, \mathbb{E}[Y \mid X, \hat Y]). s ⋆ ( X , Y ^ ) = X ( μ ( X ⊤ θ ⋆ ) − E [ Y ∣ X , Y ^ ] ) = ∇ ℓ θ ⋆ ( X , E [ Y ∣ X , Y ^ ]) .
校准预测 (Calibrated Predictions) :如果预测 Y ^ \hat Y Y ^ 是校准的,即 Y ^ = E [ Y ∣ X ] \hat Y = \mathbb{E}[Y \mid X] Y ^ = E [ Y ∣ X ] ,那么 E [ Y ∣ X , Y ^ ] = Y ^ \mathbb{E}[Y \mid X, \hat Y] = \hat Y E [ Y ∣ X , Y ^ ] = Y ^ 。在这种情况下,标准 PPI (g θ = ℓ θ g_\theta = \ell_\theta g θ = ℓ θ ) 就是最优的。
一般情况 :通常 E [ Y ∣ X , Y ^ ] \mathbb{E}[Y \mid X, \hat Y] E [ Y ∣ X , Y ^ ] 可以看作是一个重校准的预测 。这就是 RePPI(Recalibrated PPI)方法的理论基础。
例子 2:分位数回归
对于分位数回归,∇ ℓ θ ( X , Y ) = X ( τ − I ( Y ≤ X ⊤ θ ) ) \nabla \ell_\theta(X, Y) = X(\tau - \mathbb{I}(Y \le X^\top \theta)) ∇ ℓ θ ( X , Y ) = X ( τ − I ( Y ≤ X ⊤ θ )) 。此时 s ⋆ s^\star s ⋆ 的形式为:
s ⋆ ( X , Y ^ ) = X ( τ − P ( Y ≤ X ⊤ θ ⋆ ∣ X , Y ^ ) ) . s^\star(X, \hat Y) = X \left(\tau - \mathbb{P}(Y \le X^\top \theta^\star \mid X, \hat Y)\right). s ⋆ ( X , Y ^ ) = X ( τ − P ( Y ≤ X ⊤ θ ⋆ ∣ X , Y ^ ) ) .
由于涉及到指示函数 I \mathbb{I} I ,这里 s ⋆ s^\star s ⋆ 无法直接通过 ∇ ℓ θ ⋆ \nabla \ell_{\theta^\star} ∇ ℓ θ ⋆ 表示,必须显式地估计条件概率。
4. RePPI 算法
直接估计 s ⋆ ( X , Y ^ ) s^\star(X,\hat Y) s ⋆ ( X , Y ^ ) 面临两个主要困难:一是 θ ⋆ \theta^\star θ ⋆ 未知,二是条件期望 E [ ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ ] \mathbb{E}[\nabla \ell_{\theta^\star}(X, Y) \mid X, \hat Y] E [ ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ ] 可能非常复杂,难以用简单的参数形式建模。
为了解决这些问题,作者提出了一种更为实用的方法,称为 重校准 PPI (Recalibrated PPI, RePPI) 。其核心思想可以分为三个阶段:
初始估计 :先通过 X Y XY X Y -only 方法获得一个一致但未必有效的初始估计量 θ ^ 0 \hat{\theta}_0 θ ^ 0 。
灵活学习 :使用灵活的机器学习方法(如随机森林或梯度提升)在标签数据上学习条件期望 s ^ ( X , Y ^ ) ≈ E [ ∇ ℓ θ ^ 0 ( X , Y ) ∣ X , Y ^ ] \hat{s}(X, \hat Y) \approx \mathbb{E}[\nabla \ell_{\hat{\theta}_0}(X, Y) \mid X, \hat Y] s ^ ( X , Y ^ ) ≈ E [ ∇ ℓ θ ^ 0 ( X , Y ) ∣ X , Y ^ ] 。在这里我们学到的函数不一定是简单的线性问题。
安全缩放 :为了对抗 s ^ \hat{s} s ^ 可能出现的估计误差,引入一个缩放矩阵 M ^ \hat M M ^ 来调节 s ^ \hat{s} s ^ 的贡献,确保效率不低于 X Y XY X Y -only。
为了减少嵌套估计带来的偏差,作者采用了 三折交叉拟合 (Cross-Fitting) 。完整的算法流程如下:
Algorithm 1: Recalibrated Prediction-Powered Inference (RePPI)
Step 1 : 随机将标签数据集分成三个均等的子集 D 1 , D 2 , D 3 D_1, D_2, D_3 D 1 , D 2 , D 3 。
Step 2 : 在 D 1 D_1 D 1 上计算初始估计量 θ ^ 0 1 = argmin θ 1 ∣ D 1 ∣ ∑ i ∈ D 1 ℓ θ ( X i , Y i ) \hat{\theta}_0^1 = \operatorname*{argmin}_{\theta} \frac{1}{|D_1|} \sum_{i \in D_1} \ell_\theta(X_i, Y_i) θ ^ 0 1 = argmin θ ∣ D 1 ∣ 1 ∑ i ∈ D 1 ℓ θ ( X i , Y i ) 。
Step 3 : 在 D 2 D_2 D 2 上使用灵活的机器学习方法(如随机森林)估计条件期望,得到 s ^ ( X , Y ^ ) ≈ E [ ∇ ℓ θ ^ 0 1 ( X , Y ) ∣ X , Y ^ ] \hat{s}(X, \hat Y) \approx \mathbb{E}[\nabla \ell_{\hat{\theta}_0^1}(X, Y) \mid X, \hat Y] s ^ ( X , Y ^ ) ≈ E [ ∇ ℓ θ ^ 0 1 ( X , Y ) ∣ X , Y ^ ] 。
Step 4 : 在 D 3 D_3 D 3 上计算缩放矩阵 M ^ \hat M M ^ ,其定义如下(基于真实梯度与估计梯度的协方差关系):
M ^ = Cov ^ ( ∇ ℓ θ ^ 0 1 ( X , Y ) , s ^ ( X , Y ^ ) ) Cov ^ ( s ^ ( X , Y ^ ) ) − 1 . (9) \hat M = \widehat{\text{Cov}}(\nabla \ell_{\hat{\theta}_0^1}(X, Y), \hat{s}(X, \hat Y)) \; \widehat{\text{Cov}}(\hat{s}(X, \hat Y))^{-1}. \tag{9} M ^ = Cov ( ∇ ℓ θ ^ 0 1 ( X , Y ) , s ^ ( X , Y ^ )) Cov ( s ^ ( X , Y ^ ) ) − 1 . ( 9 )
Step 5 : 在 D 3 D_3 D 3 和无标签数据上,构造最终的目标函数并求解。这里,插补损失 g θ g_\theta g θ 的梯度被设计为:
∇ g θ ( X , Y ^ ) = 1 1 + n / N M ^ s ^ ( X , Y ^ ) . \nabla g_{\theta}(X, \hat Y) = \frac{1}{1 + n/N} \hat M \hat{s}(X, \hat Y). ∇ g θ ( X , Y ^ ) = 1 + n / N 1 M ^ s ^ ( X , Y ^ ) .
这样构造的估计量记为 θ ^ 1 \hat{\theta}^1 θ ^ 1 。
Step 6 : 重复 Step 2-5,进行折轮换:( D 3 , D 1 , D 2 ) (D_3, D_1, D_2) ( D 3 , D 1 , D 2 ) 和 ( D 2 , D 3 , D 1 ) (D_2, D_3, D_1) ( D 2 , D 3 , D 1 ) ,分别得到估计量 θ ^ 2 \hat{\theta}^2 θ ^ 2 和 θ ^ 3 \hat{\theta}^3 θ ^ 3 。
Step 7 : 计算最终的交叉拟合估计量:
θ ^ CrossFit = ∣ D 1 ∣ n θ ^ 1 + ∣ D 2 ∣ n θ ^ 2 + ∣ D 3 ∣ n θ ^ 3 。 \hat{\theta}^{\text{CrossFit}} = \frac{|D_1|}{n} \hat{\theta}^1 + \frac{|D_2|}{n} \hat{\theta}^2 + \frac{|D_3|}{n} \hat{\theta}^3。 θ ^ CrossFit = n ∣ D 1 ∣ θ ^ 1 + n ∣ D 2 ∣ θ ^ 2 + n ∣ D 3 ∣ θ ^ 3 。
4.1 缩放矩阵 M ^ \hat M M ^ 的作用与理论推导
为什么需要 M ^ \hat M M ^ ?如果 s ^ \hat{s} s ^ 能够完美地估计 s ⋆ s^\star s ⋆ ,那么直接使用 s ^ \hat{s} s ^ 就已经是最优的了。但在实践中,s ^ \hat{s} s ^ 可能会因为模型设定错误或样本量不足而存在偏差。如果直接使用有偏的 s ^ \hat{s} s ^ ,最终的估计量方差甚至可能 大于 仅使用 X Y XY X Y -only 的方差。
M ^ \hat M M ^ 的本质是 最优控制变量的线性投影系数 ,它通过将估计的梯度 s ^ \hat{s} s ^ 投影到真实梯度 ∇ ℓ \nabla \ell ∇ ℓ 的方向上,来最小化最终估计量的渐近方差。
理论推导
让我们从一种通用视角来看作者在公式 (9) 中给出的 M ^ \hat M M ^ 是如何确定的。在 RePPI 的最后构造中,我们利用 M ^ s ^ \hat M \hat{s} M ^ s ^ 来近似真实梯度 ∇ ℓ θ ⋆ \nabla \ell_{\theta^\star} ∇ ℓ θ ⋆ 。为了找到最优的 M ^ \hat M M ^ ,我们求解以下最小二乘问题:
M ∗ = arg min M E [ ∥ ∇ ℓ θ ⋆ ( X , Y ) − M s ^ ( X , Y ^ ) ∥ 2 ] 。 M^* = \arg\min_{M} \mathbb{E}\left[ \| \nabla \ell_{\theta^\star}(X,Y) - M \hat{s}(X,\hat Y) \|^2 \right]。 M ∗ = arg M min E [ ∥∇ ℓ θ ⋆ ( X , Y ) − M s ^ ( X , Y ^ ) ∥ 2 ] 。
对于矩阵 M M M 的优化,我们展开目标函数。注意到 ∥ A ∥ 2 = T r ( A A ⊤ ) \|A\|^2 = \mathrm{Tr}(AA^\top) ∥ A ∥ 2 = Tr ( A A ⊤ ) ,利用期望的线性性质和迹的性质,目标函数可以写为:
L ( M ) = E [ T r ( ( ∇ ℓ − M s ^ ) ( ∇ ℓ − M s ^ ) ⊤ ) ] = T r ( E [ ∇ ℓ ∇ ℓ ⊤ ] ) − 2 T r ( M E [ s ^ ∇ ℓ ⊤ ] ) + T r ( M E [ s ^ s ^ ⊤ ] M ⊤ ) 。 \begin{aligned}
\mathcal{L}(M) &= \mathbb{E}\left[ \mathrm{Tr}\left( (\nabla \ell - M\hat{s})(\nabla \ell - M\hat{s})^\top \right) \right] \\
&= \mathrm{Tr}\left( \mathbb{E}[\nabla \ell \nabla \ell^\top] \right) - 2 \mathrm{Tr}\left( M \mathbb{E}[\hat{s} \nabla \ell^\top] \right) + \mathrm{Tr}\left( M \mathbb{E}[\hat{s} \hat{s}^\top] M^\top \right)。
\end{aligned} L ( M ) = E [ Tr ( ( ∇ ℓ − M s ^ ) ( ∇ ℓ − M s ^ ) ⊤ ) ] = Tr ( E [ ∇ ℓ ∇ ℓ ⊤ ] ) − 2 Tr ( M E [ s ^ ∇ ℓ ⊤ ] ) + Tr ( M E [ s ^ s ^ ⊤ ] M ⊤ ) 。
对矩阵 M M M 求梯度并令其为零:
∂ L ∂ M = − 2 E [ ∇ ℓ s ^ ⊤ ] + 2 M E [ s ^ s ^ ⊤ ] = 0 。 \frac{\partial \mathcal{L}}{\partial M} = -2 \mathbb{E}[\nabla \ell \hat{s}^\top] + 2 M \mathbb{E}[\hat{s} \hat{s}^\top] = 0。 ∂ M ∂ L = − 2 E [ ∇ ℓ s ^ ⊤ ] + 2 M E [ s ^ s ^ ⊤ ] = 0 。
因此,最优的线性投影矩阵为:
M ∗ = E [ ∇ ℓ s ^ ⊤ ] ( E [ s ^ s ^ ⊤ ] ) − 1 。 (9*) M^* = \mathbb{E}[\nabla \ell \hat{s}^\top] \left( \mathbb{E}[\hat{s} \hat{s}^\top] \right)^{-1}。 \tag{9*} M ∗ = E [ ∇ ℓ s ^ ⊤ ] ( E [ s ^ s ^ ⊤ ] ) − 1 。 ( 9* )
在 RePPI 的实际操作中,我们使用中心化后的样本协方差矩阵来估计上式。这正是作者论文中给出的估计量:
M ^ = Cov ^ ( ∇ ℓ θ ^ 0 , s ^ ) Cov ^ ( s ^ ) − 1 。 (9) \hat M = \widehat{\text{Cov}}(\nabla \ell_{\hat{\theta}_0}, \hat{s}) \; \widehat{\text{Cov}}(\hat{s})^{-1}。 \tag{9} M ^ = Cov ( ∇ ℓ θ ^ 0 , s ^ ) Cov ( s ^ ) − 1 。 ( 9 )
为什么这个选择能保证“安全”?
从统计学角度看,M ^ s ^ \hat M \hat{s} M ^ s ^ 构成了 ∇ ℓ \nabla \ell ∇ ℓ 在 s ^ \hat{s} s ^ 张成的线性子空间上的正交投影 。根据投影定理,残余项 ∇ ℓ − M ^ s ^ \nabla \ell - \hat M \hat{s} ∇ ℓ − M ^ s ^ 将与 s ^ \hat{s} s ^ 正交。这意味着,无论 s ^ \hat{s} s ^ 的质量如何,将其作为控制变量引入时:
完全准确时 :如果 s ^ \hat{s} s ^ 与真实梯度 ∇ ℓ \nabla \ell ∇ ℓ 完全相关,M ^ \hat M M ^ 将自动缩放 s ^ \hat{s} s ^ 以精确匹配真实梯度,使得 RePPI 的方差达到理论最小值(定理1)。
完全不准确时 :如果 s ^ \hat{s} s ^ 与真实梯度 ∇ ℓ \nabla \ell ∇ ℓ 完全不相关(即全是噪音),协方差矩阵 Cov ^ ( ∇ ℓ , s ^ ) \widehat{\text{Cov}}(\nabla \ell, \hat{s}) Cov ( ∇ ℓ , s ^ ) 将趋近于零,因此 M ^ \hat M M ^ 也会收敛到零矩阵。此时,M ^ s ^ ≈ 0 \hat M \hat{s} \approx 0 M ^ s ^ ≈ 0 ,RePPI 自动退化回 X Y XY X Y -only 估计量,确保效率不会比基准方法更差。
这种设计完美实现了“在充分利用 s ^ \hat{s} s ^ 提升效率的同时,永远不会被 s ^ \hat{s} s ^ 的误差拖累”的安全机制。
4.2 理论保证
定理 2 (RePPI 的渐近性质) :
假设 E [ ∥ s ^ ( X , Y ^ ) − s ( X , Y ^ ) ∥ 2 ] → 0 \mathbb{E}[\|\hat{s}(X, \hat Y) - s(X, \hat Y)\|^2] \to 0 E [ ∥ s ^ ( X , Y ^ ) − s ( X , Y ^ ) ∥ 2 ] → 0 对某个函数 s s s 成立,则在正则条件下:
n ( θ ^ CrossFit − θ ⋆ ) → d N ( 0 , Σ s RePPI ) , \sqrt{n}(\hat{\theta}^{\text{CrossFit}} - \theta^\star) \xrightarrow{d} \mathcal{N}(0, \Sigma_s^{\text{RePPI}}), n ( θ ^ CrossFit − θ ⋆ ) d N ( 0 , Σ s RePPI ) ,
其中渐近方差为:
Σ s RePPI = H θ ⋆ − 1 ( Cov ( ∇ ℓ θ ⋆ ( X , Y ) ) − Δ ) H θ ⋆ − 1 , \Sigma_s^{\text{RePPI}} = H_{\theta^\star}^{-1} \left( \text{Cov}(\nabla \ell_{\theta^\star}(X, Y)) - \Delta \right) H_{\theta^\star}^{-1}, Σ s RePPI = H θ ⋆ − 1 ( Cov ( ∇ ℓ θ ⋆ ( X , Y )) − Δ ) H θ ⋆ − 1 ,
且 Δ = 1 1 + r Cov ( ∇ ℓ θ ⋆ , s ) Cov ( s ) − 1 Cov ( s , ∇ ℓ θ ⋆ ) \Delta = \frac{1}{1+r} \text{Cov}(\nabla \ell_{\theta^\star}, s) \text{Cov}(s)^{-1} \text{Cov}(s, \nabla \ell_{\theta^\star}) Δ = 1 + r 1 Cov ( ∇ ℓ θ ⋆ , s ) Cov ( s ) − 1 Cov ( s , ∇ ℓ θ ⋆ ) 。
定理2揭示了 RePPI 的两个关键性质:
一致性 :如果 s s s 被一致地估计为 s ⋆ s^\star s ⋆ ,即 s ^ → s ⋆ \hat{s} \to s^\star s ^ → s ⋆ ,则 RePPI 的方差正好等于定理1中定义的 最优 PPI 方差 ,达到了理论效率的极限。
鲁棒性 :即使估计的 s s s 不是最优的(即 s ≠ s ⋆ s \neq s^\star s = s ⋆ ),由于 Δ \Delta Δ 是半正定的,RePPI 的方差仍然 总是小于或等于 X Y XY X Y -only 的方差。这意味着 RePPI 永远不会比不使用预测的方法更差。
总结 :RePPI 通过引入 M ^ \hat M M ^ 和交叉拟合,同时实现了统计效率的理论极限 和算法的鲁棒性 。与传统方法相比,它无需指定 g θ g_\theta g θ 的精确参数形式,而是通过机器学习灵活逼近最优条件期望,使得其在现代 AI 场景下具有广泛的应用价值。
5. 实验结果
作者在三个真实数据集上比较 XY-only、PPI、PPI++ 和 RePPI。每个实验都从完整标注数据里随机抽出一部分作为有标签样本,其余当作无标签样本,并用完整数据上的 estimand 作为 θ ⋆ \theta^\star θ ⋆ 的 proxy。评价指标是 90% 置信区间的平均宽度和覆盖率。
US Census
目标是估计 log-income 对 age 回归的 age coefficient。作者用 XGBoost 基于 14 个其他协变量预测 log-income,包括 education、marital status、citizenship、race 等。为了制造 distribution shift,预测模型只在 college degree or above 人群上训练,但推断目标是整个人群。
实验数据共 377,575 个观测,标注比例从 2% 到 10%。结果是:各方法覆盖率大致正确,但 RePPI 的区间更短。达到同样区间长度时,RePPI 相比 PPI++ 节省约 24% 到 26% 的标签。
Politeness of Online Requests
数据来自 Stack Exchange 和 Wikipedia 上的 5512 条在线请求。真实 politeness score 是 5 名人类评估者在 1 到 25 尺度上的平均分。作者用 GPT-4o mini 对文本礼貌程度打分,目标参数是 politeness score 对 hedging indicator 回归的系数。
这个任务里,LLM 给出的礼貌分数与人类评分分布不完全对齐。RePPI 通过校准 imputed loss 减小置信区间宽度。达到给定区间长度时,相比 PPI++ 节省约 5% 到 7% 的标签。
Wine Reviews
数据来自 WineEnthusiast wine review。真实评分在 80 到 100 之间,作者用 GPT-4o mini 根据品酒评论文本预测评分。目标参数是控制 wine region 后,rating 对 price 的回归系数。作者从美国葡萄酒样本中抽取 10,000 个观测做实验。
结果同样显示 RePPI 在不同标注比例下都优于其他方法。达到同样区间长度时,相比 PPI++ 节省约 16% 到 20% 的标签。作者给出的解释是:酒评由单一品酒者打分,主观性较强,通用语言模型的预测分布容易与真实评分尺度错位,因此 recalibration 尤其重要。
8. Summary
这篇文章的核心不是“用 AI 预测代替真实标签”,而是“把 AI 预测当作一个外部训练得到的控制变量”。它不改变目标参数,也不要求黑箱预测无偏;它只要求有标签样本和无标签样本里的 ( X , Y ^ ) (X,\hat Y) ( X , Y ^ ) 同分布,从而让 imputed loss 差分项在总体上为零。
最关键的数学对象是
s ⋆ ( X , Y ^ ) = E { ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ } . s^\star(X,\hat Y)
=
E\{\nabla\ell_{\theta^\star}(X,Y)\mid X,\hat Y\}. s ⋆ ( X , Y ^ ) = E { ∇ ℓ θ ⋆ ( X , Y ) ∣ X , Y ^ } .
如果把 ∇ ℓ θ ⋆ ( X , Y ) \nabla\ell_{\theta^\star}(X,Y) ∇ ℓ θ ⋆ ( X , Y ) 看作目标估计方程的噪声来源,那么 s ⋆ s^\star s ⋆ 就是可以由预测值解释掉的那一部分噪声。RePPI 的全部效率提升,都来自把这部分噪声以控制变量的形式减掉。
这篇文章也提醒了一件很实际的事:预训练模型预测越强,不代表直接 PPI 越好。只要存在 modality mismatch、distribution shift 或 discrete uncalibrated prediction,直接使用 ℓ θ ( X , Y ^ ) \ell_\theta(X,\hat Y) ℓ θ ( X , Y ^ ) 就可能不是好控制变量。统计上真正有用的是校准后的 score,而不是原始预测。
局限也很明确:
方法依赖 labeled 和 unlabeled 数据中 ( X , Y ^ ) (X,\hat Y) ( X , Y ^ ) 的分布一致。
如果 s ^ \hat s s ^ 学得很差,RePPI 仍然不输 XY-only 的一阶方差,但可能无法接近最优效率。
文章主要讨论通过 Y ^ \hat Y Y ^ 使用高维 W W W ;如果研究者能直接建模 E { ∇ ℓ ( X , Y ) ∣ X , W } E\{\nabla\ell(X,Y)\mid X,W\} E { ∇ ℓ ( X , Y ) ∣ X , W } ,理论上还可以更有效。
实验里的 recalibration 多用线性回归;更复杂任务中,如何选择稳定的 s ^ \hat s s ^ 学习器仍然是工程和统计上的关键问题。
重点参考文献
Robins, J. M., Rotnitzky, A., and Zhao, L. P. (1994). Estimation of Regression Coefficients When Some Regressors Are Not Always Observed . Journal of the American Statistical Association, 89(427), 846-866.
Chen, Y.-H., and Chen, H. (2000). A Robust Imputation Method for Surrogate Outcome Data . Biometrika, 87(3), 711-716.
Chen, S. X., Leung, D. H. Y., and Qin, J. (2008). Improving Semiparametric Estimation by Using Surrogate Data . Journal of the Royal Statistical Society: Series B, 70(4), 803-823.
Angelopoulos, A. N., Bates, S., Fannjiang, C., Jordan, M. I., and Zrnic, T. (2023). Prediction-Powered Inference . Science, 382(6671), 669-674.
Angelopoulos, A. N., Duchi, J. C., and Zrnic, T. (2023). PPI++: Efficient Prediction-Powered Inference . arXiv preprint arXiv:2311.01453.
Gronsbell, J., Gao, J., Shi, Y., McCaw, Z. R., and Cheng, D. (2024). Another Look at Inference After Prediction . arXiv preprint arXiv:2411.19908.