0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

机器学习从业者指出了一个明显的问题:你如何调试模型?

智能感知与物联网技术研究所 来源:lp 2019-04-01 15:17 次阅读

当你花了几个星期构建一个数据集、编码一个神经网络并训练好了模型,然后发现结果并不理想,接下来你会怎么做?

深度学习通常被视为一个黑盒子,我并不反对这种观点——但是你能讲清楚学到的上万参数的意义吗?

但是黑盒子的观点为机器学习从业者指出了一个明显的问题:你如何调试模型?

在这篇文章中,我将会介绍一些我们在 Cardiogram 中调试 DeepHeart 时用到的技术,DeepHeart 是使用来自 Apple Watch、 Garmin、和 WearOS 的数据预测疾病的。

在 Cardiogram 中,我们认为构建 DNN 并不是炼金术,而是工程学。

你的心脏暴露了很多你的信息。DeepHeart 使用来自 Apple Watch、 Garmin、和 WearOS 的心率数据来预测你患糖尿病、高血压以及睡眠窒息症(sleep apnea)的风险。

一、预测合成输出

通过预测根据输入数据构建的合成输出任务来测试模型能力。

我们在构建检测睡眠窒息症的模型时使用了这个技术。现有关于睡眠窒息症筛查的文献使用日间和夜间心率标准差的差异作为筛查机制。因此我们为每周的输入数据创建了合成输出任务:

标准差 (日间心率)—标准差 (夜间心率)

为了学习这个函数,模型要能够:

1. 区分白天和黑夜

2. 记住过去几天的数据

这两个都是预测睡眠窒息症的先决条件,所以我们使用新架构进行实验的第一步就是检查它是否能学习这个合成任务。

你也可以通过在合成任务上预训练网络,以半监督的形式来使用类似这样的合成任务。当标记数据很稀缺,而你手头有大量未标记数据时,这种方法很有用。

二、可视化激活值

理解一个训练好的模型的内部机制是很难的。你如何理解成千上万的矩阵乘法呢?

在这篇优秀的 Distill 文章《Four Experiments in Handwriting with a Neural Network》中,作者通过在热图中绘制单元激活值,分析了手写模型。我们发现这是一个「打开 DNN 引擎盖」的好方法。

我们检查了网络中几个层的激活值,希望能够发现一些语义属性,例如,当用户在睡觉、工作或者焦虑时,激活的单元是怎样的?

用 Keras 写的从模型中提取激活值的代码很简单。下面的代码片段创建了一个 Keras 函数 last_output_fn,该函数在给定一些输入数据的情况下,能够获得一层的输出(即它的激活值)。

fromkerasimportbackendasKdefextract_layer_output(model,layer_name,input_data):layer_output_fn=K.function([model.layers[0].input],[model.get_layer(layer_name).output])layer_output=layer_output_fn([input_data])#layer_output.shapeis(num_units,num_timesteps)returnlayer_output[0]

我们可视化了网络好几层的激活值。在检查第二个卷积层(一个宽为 128 的时间卷积层)的激活值时,我们注意到了一些奇怪的事:

卷积层的每个单元在每个时间步长上的激活值。蓝色的阴影代表的是激活值。

激活值竟然不是随着时间变化的!它们不受输入值影响,被称为「死神经元」。

ReLU 激活函数,f(x) = max(0, x)

这个架构使用了 激活函数,当输入是负数的时候它输出的是 0。尽管它是这个神经网络中比较浅的层,但是这确实是实际发生的事情。

在训练的某些时候,较大的梯度会把某一层的所有偏置项都变成负数,使得 ReLU 函数的输入是很小的负数。因此这层的输出就会全部为 0,因为对小于 0 的输入来说,ReLU 的梯度为零,这个问题无法通过来解决。

当一个卷积层的输出全部为零时,后续层的单元就会输出其偏置项的值。这就是这个层每个单元输出一个不同值的原因——因为它们的偏置项不同。

我们通过用 Leaky ReLU 替换 ReLU 解决了这个问题,前者允许梯度传播,即使输入为负时。

我们没想到会在此次分析中发现「死神经元」,但最难找到的错误是你没打算找的。

三、梯度分析

梯度的作用当然不止是优化损失函数。在梯度下降中,我们计算与Δparameter 对应的Δloss。尽管通常意义上梯度计算的是改变一个变量对另一个变量的影响。由于梯度计算在梯度下降方法中是必需的,所以像 TensorFlow 这样的框架都提供了计算梯度的函数。

我们使用梯度分析来确定我们的深度神经网络能否捕捉数据中的长期依赖。DNN 的输入数据特别长:4096 个时间步长的心率或者计步数据。我们的模型架构能否捕捉数据中的长期依赖非常重要。例如,心率的恢复时间可以预测糖尿病。这就是锻炼后恢复至休息时的心率所耗的时间。为了计算它,深度神经网络必须能够计算出你休息时的心率,并记住你结束锻炼的时间。

衡量模型能否追踪长期依赖的一种简单方法是去检查输入数据的每个时间步长对输出预测的影响。如果后面的时间步长具有特别大的影响,则说明模型没有有效地利用早期数据。

对于所有时间步长 t,我们想要计算的梯度是与Δinput_t 对应的Δoutput。下面是用 Keras 和 TensorFlow 计算这个梯度的代码示例:

defgradient_output_wrt_input(model,data):#[:,2048,0]meansallusersinbatch,midpointtimestep,0thtask(diabetes)output_tensor=model.model.get_layer('raw_output').output[:,2048,0]#output_tensor.shape==(num_users)#Averageoutputoverallusers.Resultisascalar.output_tensor_sum=tf.reduce_mean(output_tensor)inputs=model.model.inputs#(num_usersxnum_timestepsxnum_input_channels)gradient_tensors=tf.gradients(output_tensor_sum,inputs)#gradient_tensors.shape==(num_usersxnum_timestepsxnum_input_channels)#Averageoverusersgradient_tensors=tf.reduce_mean(gradient_tensors,axis=0)#gradient_tensors.shape==(num_timestepsxnum_input_channels)#eggradient_tensor[10,0]isderivoflastoutputwrt10thinputheartrate#ConverttoKerasfunctionk_gradients=K.function(inputs=inputs,outputs=gradient_tensors)#Applyfunctiontodatasetreturnk_gradients([data.X])

在上面的代码中,我们在平均池化之前,在中点时间步长 2048 处计算了输出。我们之所以使用中点而不是最后的时间步长的原因是,我们的 LSTM 单元是双向的,这意味着对一半的单元来说,4095 实际上是第一个时间步长。我们将得到的梯度进行了可视化:

Δoutput_2048 / Δinput_t

请注意我们的 y 轴是 log 尺度的。在时间步长 2048 处,与输入对应的输出梯度是 0.001。但是在时间步长 2500 处,对应的梯度小了一百万倍!通过梯度分析,我们发现这个架构无法捕捉长期依赖。

四、分析模型预测

你可能已经通过观察像 AUROC 和平均绝对误差这样的指标分析了模型预测。你还可以用更多的分析来理解模型的行为。

例如,我们好奇 DNN 是否真的用心率输入来生成预测,或者说它的学习是不是严重依赖于所提供的元数据——我们用性别、年龄这样的用户元数据来初始化 LSTM 的状态。为了理解这个,我们将模型与在元数据上训练的 logistic 回归模型做了对比。

DNN 模型接收了一周的用户数据,所以在下面的散点图中,每个点代表的是一个用户周。

这幅图验证了我们的猜想,因为预测结果并不是高度相关的。

除了进行汇总分析,查看最好和最坏的样本也是很有启发性的。对一个二分类任务而言,你需要查看最令人震惊的假阳性和假阴性(也就是预测距离标签最远的情况)。尝试鉴别损失模式,然后过滤掉在你的真阳性和真阴性中出现的这种模式。

一旦你对损失模式有了假设,就通过分层分析进行测试。例如,如果最高损失全部来自第一代 Apple Watch,我们可以用第一代 Apple Watch 计算我们的调优集中用户集的准确率指标,并将这些指标与在剩余调优集上计算的指标进行比较。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4575

    浏览量

    98775
  • 机器学习
    +关注

    关注

    66

    文章

    8134

    浏览量

    130577
  • 数据集
    +关注

    关注

    4

    文章

    1179

    浏览量

    24356

原文标题:你用什么方法调试深度神经网络?这里有四种简单的方式哦

文章出处:【微信号:tyutcsplab,微信公众号:智能感知与物联网技术研究所】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    10个问题及答案保障LED从业者用电安全

    照明开关为何必须接在火线上?单相三孔插座如何安装才正确?为什么?塑料绝缘导线为什么严禁直接埋在墙内?本文从十个问题以及答案方面解析了LED从业者需要注意的用电安全。
    发表于 02-24 09:41 928次阅读

    开关电源从业者必看资料——《开关电源常规测试项目》

    开关电源从业者必看资料——《开关电源常规测试项目》
    发表于 04-29 22:03

    25机器学习面试题,都会吗?

    问题都没有给出明确的答案,但都有定的提示。读者也可以在留言中尝试。许多数据科学家主要是从数据从业者的角度来研究机器
    发表于 09-29 09:39

    电源从业者必看必会之变压器基础知识_制作流程_详解

    适合电源从业者的基础知识入门维修必看
    发表于 11-10 20:42

    电源从业者必知必会之12种开关电源拓扑及计算公式

    =oxh_wx3、【周启全老师】开关电源全集http://t.elecfans.com/topic/130.html?elecfans_trackid=oxh_wx 电源从业者必知必会12种开关电源拓扑及计算公式
    发表于 06-02 22:03

    开关电源从业者入门资料 开关电源结构和基本原理

    非常适合开关电源从业者入门的资料 开关电源结构和基本原理资料分享来自于网络资源
    发表于 09-05 21:42

    软件测试从业者需要具备哪些技能

    测试从业者,仅仅会些硬技能还不够。还需要具备些软技能。软技能质量意识(很多时候,团队中,并不缺技术,唯独缺质量意识)好的工作习惯(每天把不懂的内容,用本子记下来,弄懂为止,几年后
    发表于 07-16 16:22

    软件测试从业者需要具备哪些技能

    测试从业者,仅仅会些硬技能还不够。还需要具备些软技能。软技能质量意识(很多时候,团队中,并不缺技术,唯独缺质量意识)好的工作习惯(每天把不懂的内容,用本子记下来,弄懂为止,几年后
    发表于 11-23 10:00

    电池配组方案(电池修复从业者必读)

    电池配组方案(电池修复从业者必读)    很多朋友来问:为什么电池修复好,经过测试性能相当
    发表于 11-16 13:44 4955次阅读

    机器学习从业者工具使用方面大数据分析

    数据科学是个变化极快的领域,业内人员需要不断更新知识体系,才可以在业内保持一定地位,不被时代淘汰。Stack Overflow Q&A、Conferences 和 Podcasts 是已从业者经常使用的学习平台。
    的头像 发表于 12-04 16:34 3901次阅读
    <b class='flag-5'>机器</b><b class='flag-5'>学习</b><b class='flag-5'>从业者</b>工具使用方面大数据分析

    NVIDIA持续助力AI教育及研究从业者

    AI教育关乎未来发展。一直以来,NVIDIA都坚持与AI教育和研究从业者并肩前行,以促进AI技术的普及,迎接AI驱动下的未来经济。
    的头像 发表于 09-12 14:09 3025次阅读

    刘铁岩谈机器学习:随波逐流的太多

    机器学习从业者在当下需要掌握哪些前沿技术?展望未来,又会有哪些技术趋势值得期待?
    的头像 发表于 01-05 10:58 2593次阅读

    已入冰点的通信行业 从业者该清醒清醒了

    伴随着营业收入出现拐点,通信行业已经跌入冰点。面对如此不利的经营形式,通信行业从业企业和从业者该醒醒了。
    的头像 发表于 07-12 17:04 2270次阅读

    ML从业者如何阅读研究论文

      数据科学专业的性质是非常实用和涉及的。这意味着,数据科学领域与人工智能密切相关,人工智能仍然是一个发展中的领域,因此,它的从业者必须具备学术思维。
    的头像 发表于 04-08 14:22 648次阅读
    ML<b class='flag-5'>从业者</b>如何阅读研究论文

    作为IT从业者,为什么我推荐华为云ECS?

    作为一名IT从业者,我经常会被问到关于云计算的问题。在这里,我想分享一下我的看法,并向大家推荐华为云ECS。 首先,让我们来了解一下什么是弹性云服务器。弹性云服务器(Elastic Cloud
    的头像 发表于 06-24 00:26 377次阅读