Paper Title
Correct-N-Contrast: A Contrastive Approach for Improving Robustness to Spurious Correlations
Paper Authors
Paper Abstract
Spurious correlations pose a major challenge for robust machine learning. Models trained with empirical risk minimization (ERM) may learn to rely on correlations between class labels and spurious attributes, leading to poor performance on data groups without these correlations. This is particularly challenging to address when spurious attribute labels are unavailable. To improve worst-group performance on spuriously correlated data without training attribute labels, we propose Correct-N-Contrast (CNC), a contrastive approach to directly learn representations robust to spurious correlations. As ERM models can be good spurious attribute predictors, CNC works by (1) using a trained ERM model's outputs to identify samples with the same class but dissimilar spurious features, and (2) training a robust model with contrastive learning to learn similar representations for same-class samples. To support CNC, we introduce new connections between worst-group error and a representation alignment loss that CNC aims to minimize. We empirically observe that worst-group error closely tracks with alignment loss, and prove that the alignment loss over a class helps upper-bound the class's worst-group vs. average error gap. On popular benchmarks, CNC reduces alignment loss drastically, and achieves state-of-the-art worst-group accuracy by 3.6% average absolute lift. CNC is also competitive with oracle methods that require group labels.
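The two-stage procedure described in the abstract can be sketched as follows. This is a simplified illustration, not the authors' implementation: `cnc_pairs` and `contrastive_loss` are hypothetical names, and the loss shown is a basic InfoNCE-style objective standing in for the paper's contrastive loss. For each anchor, positives share its class label but receive a different ERM prediction (a proxy for dissimilar spurious features), while negatives have a different class label but the same ERM prediction.

```python
import numpy as np

def cnc_pairs(labels, erm_preds):
    """Stage 1 (sketch): use a trained ERM model's predictions to pick,
    for each anchor i, positives (same class, different ERM prediction)
    and negatives (different class, same ERM prediction)."""
    pairs = []
    n = len(labels)
    for i in range(n):
        pos = [j for j in range(n)
               if j != i and labels[j] == labels[i] and erm_preds[j] != erm_preds[i]]
        neg = [j for j in range(n)
               if labels[j] != labels[i] and erm_preds[j] == erm_preds[i]]
        pairs.append((pos, neg))
    return pairs

def contrastive_loss(z, anchor, pos, neg, tau=0.5):
    """Stage 2 (sketch): InfoNCE-style loss on representations z that
    pulls positives toward the anchor and pushes negatives away, so
    same-class samples get similar representations."""
    def sim(a, b):  # cosine similarity
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    losses = []
    for p in pos:
        num = np.exp(sim(z[anchor], z[p]) / tau)
        den = num + sum(np.exp(sim(z[anchor], z[nn]) / tau) for nn in neg)
        losses.append(-np.log(num / den))
    return float(np.mean(losses))
```

As a usage illustration, with labels `[0, 0, 1, 1]` and ERM predictions `[0, 1, 1, 0]` (the ERM model mispredicts exactly the samples whose spurious feature disagrees with their class), anchor 0's positive is sample 1 and its negative is sample 3; representations that align same-class samples then incur a lower loss than representations that align same-spurious-feature samples.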