0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

为什么深度学习模型经常出现预测概率和真实情况差异大的问题?

深度学习自然语言处理 来源:圆圆的算法笔记 作者:Fareise 2022-09-09 17:11 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

大家在训练深度学习模型的时候,有没有遇到这样的场景:分类任务的准确率比较高,但是模型输出的预测概率和实际预测准确率存在比较大的差异?这就是现代深度学习模型面临的校准问题。在很多场景中,我们不仅关注分类效果或者排序效果(auc),还希望模型预测的概率也是准的。例如在自动驾驶场景中,如果模型无法以置信度较高的水平检测行人或障碍物,就应该通过输出概率反映出来,并让模型依赖其他信息进行决策。再比如在广告场景中,ctr预测除了给广告排序外,还会用于确定最终的扣费价格,如果ctr的概率预测的不准,会导致广告主的扣费偏高或偏低。

那么,为什么深度学习模型经常出现预测概率和真实情况差异大的问题?又该如何进行校准呢?这篇文章首先给大家介绍模型输出预测概率不可信的原因,再为大家通过10篇顶会论文介绍经典的校准方法,可以适用于非常广泛的场景

1 为什么会出现校准差的问题

最早进行系统性的分析深度学习输出概率偏差问题的是2017年在ICML发表的一篇文章On calibration of modern neural networks(ICML 2017)。文中发现,相比早期的简单神经网络模型,现在的模型越来越大,效果越来越好,但同时模型的校准性越来越差。文中对比了简单模型LeNet和现代模型ResNet的校准情况,LeNet的输出结果校准性很好,而ResNet则出现了比较严重的过自信问题(over-confidence),即模型输出的置信度很高,但实际的准确率并没有那么高。

d29db7d4-2f5d-11ed-ba43-dac502259ad0.png

造成这个现象的最本质原因,是模型对分类问题通常使用的交叉熵损失过拟合。并且模型越复杂,拟合能力越强,越容易过拟合交叉熵损失,带来校准效果变差。这也解释了为什么随着深度学习模型的发展,校准问题越来越凸显出来。

那么为什么过拟合交叉熵损失,就会导致校准问题呢?因为根据交叉熵损失的公式可以看出,即使模型已经在正确类别上的输出概率值最大(也就是分类已经正确了),继续增大对应的概率值仍然能使交叉熵进一步减小。因此模型会倾向于over-confident,即对于样本尽可能的让模型预测为正确的label对应的概率接近1。模型过拟合交叉熵,带来了分类准确率的提升,但是牺牲的是模型输出概率的可信度。

如何解决校准性差的问题,让模型输出可信的概率值呢?业内的主要方法包括后处理和在模型中联合优化校准损失两个方向,下面给大家分别进行介绍。

2 后处理校准方法

后处理校准方法指的是,先正常训练模型得到初始的预测结果,再对这些预测概率值进行后处理,让校准后的预测概率更符合真实情况。典型的方法包括Histogram binning(2001)Isotonic regression(2002)Platt scaling(1999)

Histogram binning是一种比较简单的校准方法,根据初始预测结果进行排序后分桶,每个桶内求解一个校准后的结果,落入这个桶内的预测结果,都会被校准成这个值。每个桶校准值的求解方法是利用一个验证集进行拟合,求解桶内平均误差最小的值,其实也就是落入该桶内正样本的比例。

Isotonic regression是Histogram binning一种扩展,通过学习一个单调增函数,输入初始预测结果,输出校准后的预测结果,利用这个单调增函数最小化预测值和label之间的误差。保序回归就是在不改变预测结果的排序(即不影响模型的排序能力),通过修改每个元素的值让整体的误差最小,进而实现模型纠偏。

Platt scaling则直接使用一个逻辑回归模型学习基础预测值到校准预测值的函数,利用这个函数实现预测结果校准。在获得基础预估结果后,以此作为输入,训练一个逻辑回归模型,拟合校准后的结果,也是在一个单独的验证集上进行训练。这个方法的问题在于对校准前的预测值和真实值之间的关系做了比较强分布假设。

3 在模型中进行校准

除了后处理的校准方法外,一些在模型训练过程中实现校准的方法获得越来越多的关注。在模型中进行校准避免了后处理的两阶段方式,主要包括在损失函数中引入校准项、label smoothing以及数据增强三种方式

基于损失函数的校准方法最基础的是On calibration of modern neural networks(ICML 2017)这篇文章提出的temperature scaling方法。Temperature scaling的实现方式很简单,把模型最后一层输出的logits(softmax的输入)除以一个常数项。这里的temperature起到了对logits缩放的作用,让输出的概率分布熵更大(温度系数越大越接近均匀分布)。同时,这样又不会改变原来预测类别概率值的相对排序,因此理论上不会对模型准确率产生负面影响。

Trainable calibration measures for neural networks from kernel mean embeddings(2018)这篇文章中,作者直接定义了一个可导的校准loss,作为一个辅助loss在模型中和交叉熵loss联合学习。本文定义的MMCE原理来自评估模型校准度的指标,即模型输出类别概率值与模型正确预测该类别样本占比的差异。

Calibrating deep neural networks using focal loss(NIPS 2020)中,作者提出直接使用focal loss替代交叉熵损失,就可以起到校准作用。Focal loss是表示学习中的常用函数,对focal loss不了解的同学可以参考之前的文章:表示学习中的7大损失函数梳理。作者对focal loss进行推倒,可以拆解为如下两项,分别是预测分布与真实分布的KL散度,以及预测分布的熵。KL散度和一般的交叉熵作用相同,而第二项在约束模型输出的预测概率值熵尽可能大,其实和temperature scaling的原理类似,都是缓解模型在某个类别上打分太高而带来的过自信问题:

d2c6ff68-2f5d-11ed-ba43-dac502259ad0.png

除了修改损失函数实现校准的方法外,label smoothing也是一种常用的校准方法,最早在Regularizing neural networks by penalizing confident output distributions(ICLR 2017)中提出了label smoothing在模型校准上的应用,后来又在When does label smoothing help? (NIPS 2019)进行了更加深入的探讨。Label smoothing通过如下公式对原始的label进行平滑操作,其原理也是增大输出概率分布的熵:

d2d8641a-2f5d-11ed-ba43-dac502259ad0.png

此外,一些研究也研究了数据增强手段对模型校准的影响。On mixup training: Improved calibration and predictive uncertainty for deep neural networks(NIPS 2019)提出mixup方法可以有效提升模型校准程度。Mixup是一种简单有效的数据增强策略,具体实现上,随机从数据集中抽取两个样本,将它们的特征和label分别进行加权融合,得到一个新的样本用于训练:

d2e354d8-2f5d-11ed-ba43-dac502259ad0.png

文中作者提出,上面融合过程中对label的融合对取得校准效果好的预测结果是非常重要的,这和上面提到的label smoothing思路比较接近,让label不再是0或1的超低熵分布,来缓解模型过自信问题。

类似的方法还包括CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features(ICCV 2019)提出的一种对Mixup方法的扩展,随机选择两个图像和label后,对每个patch随机选择是否使用另一个图像相应的patch进行替换,也起到了和Mixup类似的效果。文中也对比了Mixup和CutMix的效果,Mixup由于每个位置都进行插值,容易造成区域信息的混淆,而CutMix直接进行替换,不同区域的差异更加明确。

d2f70370-2f5d-11ed-ba43-dac502259ad0.png

4 总结

本文梳理了深度学习模型的校准方法,包含10篇经典论文的工作。通过校准,可以让模型输出的预测概率更加可信,可以应用于各种类型、各种场景的深度学习模型中,适用场景非常广泛。




审核编辑:刘清

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4842

    浏览量

    108180

原文标题:不要相信模型输出的概率打分......

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    揭秘TEE深度休眠唤醒“低概率报错”:从概念到解决方案的全解析

    在嵌入式与物联网设备的底层技术领域,TEE(可信执行环境) 是保障系统安全的关键组件之一。但在 RK3562、RK3588 等芯片的深度休眠唤醒场景中,却出现了一类 “低概率却影响致命” 的报错问题。今天我们就从概念入手,一步步
    的头像 发表于 02-09 16:37 341次阅读
    揭秘TEE<b class='flag-5'>深度</b>休眠唤醒“低<b class='flag-5'>概率</b>报错”:从概念到解决方案的全解析

    从数据到模型:如何预测细节距键合的剪切力?

    在微电子封装领域,细节距键合工艺的开发与质量控制面临着巨大挑战。工程师们常常需要在缺乏大量破坏性测试的前提下,快速评估或预测一个键合点的剪切力性能。能否根据焊球的表观尺寸,通过一个可靠的数学模型
    发表于 01-08 09:45

    机器学习深度学习中需避免的 7 个常见错误与局限性

    无论你是刚入门还是已经从事人工智能模型相关工作一段时间,机器学习深度学习中都存在一些我们需要时刻关注并铭记的常见错误。如果对这些错误置之不理,日后可能会引发诸多麻烦!只要我们密切关注
    的头像 发表于 01-07 15:37 350次阅读
    机器<b class='flag-5'>学习</b>和<b class='flag-5'>深度</b><b class='flag-5'>学习</b>中需避免的 7 个常见错误与局限性

    穿孔机顶头检测仪 机器视觉深度学习

    ,能适用恶劣工况,在粉尘、高温、氧化皮等恶劣环境中均可正常工作。 测量原理 利用顶头与周围的物质(水、空气、导盘等)红外辐射能量的差异,用热成像相机拍摄出清晰的图片,再通过深度学习短时间内深度
    发表于 12-22 14:33

    模型赋能物资需求精准预测与采购系统:功能特点与平台架构解析

        大模型赋能物资需求预测与采购智能化:核心功能与价值解析    大模型赋能物资需求精准预测与采购系统通过深度整合多源数据、构建动态
    的头像 发表于 12-16 11:54 465次阅读

    世界模型是让自动驾驶汽车理解世界还是预测未来?

      [首发于智驾最前沿微信公众号]世界模型在自动驾驶技术中已有广泛应用。但当谈及它对自动驾驶的作用时,难免会出现分歧。它到底是让自动驾驶汽车得以理解世界,还是为其提供了预测未来的视角? 世界
    的头像 发表于 12-16 09:27 1020次阅读
    世界<b class='flag-5'>模型</b>是让自动驾驶汽车理解世界还是<b class='flag-5'>预测</b>未来?

    攻击逃逸测试:深度验证网络安全设备的真实防护能力

    攻击逃逸测试通过主动模拟协议混淆、流量分割、时间延迟等高级规避技术,能够深度验证网络安全设备的真实防护能力。这种测试方法不仅能精准暴露检测引擎的解析盲区和策略缺陷,还能有效评估防御体系在面对隐蔽攻击
    发表于 11-17 16:17

    设备出现通信问题的概率大吗?

    设备出现通信问题的概率并非固定值,而是受 “通信链路类型(有线 / 无线)、应用场景(工业 / 民用)、设备老化程度、设计安装规范度、干扰源强度” 等多因素影响,整体呈现 “工业场景高于民用
    的头像 发表于 09-25 14:08 749次阅读
    设备<b class='flag-5'>出现</b>通信问题的<b class='flag-5'>概率</b>大吗?

    如何在机器视觉中部署深度学习神经网络

    图 1:基于深度学习的目标检测可定位已训练的目标类别,并通过矩形框(边界框)对其进行标识。 在讨论人工智能(AI)或深度学习时,经常会出现
    的头像 发表于 09-10 17:38 1052次阅读
    如何在机器视觉中部署<b class='flag-5'>深度</b><b class='flag-5'>学习</b>神经网络

    基于极海APM32F103的USB键盘与虚拟串口例程

    最近在编写DMA_ADC例程的过程中出现了一个中断配置的问题,在ADC采集过程中,结合手册进行ADC连续转换模式配置采集,手册上给出需要进行中断配置的信息,但是真实情况不需要进行中断配置也可以进行
    的头像 发表于 08-16 09:20 1759次阅读
    基于极海APM32F103的USB键盘与虚拟串口例程

    自动驾驶中Transformer大模型会取代深度学习吗?

    [首发于智驾最前沿微信公众号]近年来,随着ChatGPT、Claude、文心一言等大语言模型在生成文本、对话交互等领域的惊艳表现,“Transformer架构是否正在取代传统深度学习”这一话题一直被
    的头像 发表于 08-13 09:15 4367次阅读
    自动驾驶中Transformer大<b class='flag-5'>模型</b>会取代<b class='flag-5'>深度</b><b class='flag-5'>学习</b>吗?

    晶圆切割深度动态补偿的智能决策模型与 TTV 预测控制

    摘要:本文针对超薄晶圆切割过程中 TTV 均匀性控制难题,研究晶圆切割深度动态补偿的智能决策模型与 TTV 预测控制方法。分析影响切割深度与 TTV 的关键因素,阐述智能决策
    的头像 发表于 07-23 09:54 692次阅读
    晶圆切割<b class='flag-5'>深度</b>动态补偿的智能决策<b class='flag-5'>模型</b>与 TTV <b class='flag-5'>预测</b>控制

    瑞芯微3576,使用FP16模型进行训练,瑞芯微官方接口概率崩溃

    corrupted。之前使用INT8的模型出现过这个错误。使用的是model_zoo中的aarch64下的librknnrt.so。未崩溃时能正常检测,可以确认崩溃前没有内存泄漏或者不足的情况
    发表于 07-17 13:25

    模型推理显存和计算量估计方法研究

    方法。 一、引言 大模型推理是指在已知输入数据的情况下,通过深度学习模型进行预测或分类的过程。然
    发表于 07-03 19:43

    基于APM32F411 DMA_ADC Handler模式分析及解决

    最近在编写DMA_ADC例程的过程中出现了一个中断配置的问题,在ADC采集过程中,结合手册进行ADC连续转换模式配置采集,手册上给出需要进行中断配置的信息,但是真实情况不需要进行中断配置也可以进行
    的头像 发表于 06-24 14:30 1220次阅读
    基于APM32F411 DMA_ADC Handler模式分析及解决