0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

利用TensorFlow.js,D3.js 和 Web 的力量使训练模型的过程可视化

Tensorflowers 来源:未知 作者:李倩 2018-08-08 14:24 次阅读

在这篇文章中,我们将利用 TensorFlow.js,D3.js 和 Web 的力量使训练模型的过程可视化,以预测棒球数据中的坏球(蓝色区域)和好球(橙色区域)。 随着我们的进展,我们将模型在整个训练过程中理解的打击区域可视化。您可以通过访问此 Observable 笔记本在浏览器中运行此模型。

注:Observable链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d

如果你不熟悉棒球的击球区,这里有一篇详细的文章。

上面的 GIF 可视化神经网络学习调用坏球(蓝色区域)和好球(橙色区域)在每个训练步骤之后,热图会根据模型的预测进行更新

使用 Observable 直接在浏览器中运行此模型。

注:文章链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d

体育运动中的高级指标

当今的职业体育环境中充斥着大量的数据。这些数据被团队,业余爱好者和粉丝应用于各种用例中。感谢像 TensorFlow 这样的框架 - 这些数据集已准备好应用于机器学习

美国职业棒球大联盟先进媒体(MLBAM)的 PITCHf/x

美国职业棒球大联盟先进媒体(MLBAM)发布了一个可供公众研究的大型数据集。该数据集包含有关过去几年在美国职业棒球大联盟比赛中投掷的投球的传感器信息。 利用这个数据集,我们已编写了一个包含 5,000 个样本的训练集(2,500 个坏球和 2,500 个好球)。

以下是训练数据中前几个字段的示例:

注:示例链接

https://gist.github.com/nkreeger/01b5386b522b0cd1f22bc864320f3084#file-baseball-training-data-sample-csv

以下是针对打击区域绘制的训练数据的样子。蓝点标记为坏球,橙点标记为好球(此为大联盟裁判员称谓):

利用 TensorFlow.js 构建模型

TensorFlow.js 将机器学习引入 JavaScript 和 Web。 我们将利用这个很棒的框架来构建一个深度神经网络模型。这个模型将能够按大联盟裁判的精准度来称呼好球和坏球。

输入 Input

该模型在 PITCHf / x 的以下字段中进行了训练:

协调球越过本垒的位置('px'和'pz')。

击球手站在垒的哪一侧。

击球区(击球手的躯干)的高度,以英尺为单位。

击球区底部的高度(击球手的膝盖)以英尺为单位。

裁判所称的投球(好球或坏球)的实际标签

结构 Architecture

该模型将通过使用 TensorFlow.js 图层 API 定义。Layers API 基于 Keras,对以前使用过该框架的人来说应该很熟悉:

1const model = tf.sequential();

2

3// Two fully connected layers with dropout between each:

4model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));

5model.add(tf.layers.dropout({rate: 0.01}));

6model.add(tf.layers.dense({units: 16, activation: 'relu'}));

7model.add(tf.layers.dropout({rate: 0.01}));

8

9// Only two classes: "strike" and "ball":

10model.add(tf.layers.dense({units: 2, activation: 'softmax'}));

11

12model.compile({

13optimizer: tf.train.adam(0.01),

14loss: 'categoricalCrossentropy',

15metrics: ['accuracy']

16});

加载和准备数据

精选的训练集可通过GitHub gist 获得。需要下载此数据集才能开始将 CSV 数据转换为 TensorFlow.js 用于训练的格式。

注:GitHub gist 链接

https://gist.github.com/nkreeger/43edc6e6daecc2cb02a2dd3293a08f29

1const data = [];

2csvData.forEach((values) => {

3// 'logit' data uses the 5 fields:

4const x = [];

5x.push(parseFloat(values.px));

6x.push(parseFloat(values.pz));

7x.push(parseFloat(values.sz_top));

8x.push(parseFloat(values.sz_bot));

9x.push(parseFloat(values.left_handed_batter));

10// The label is simply 'is strike' or 'is ball':

11const y = parseInt(values.is_strike, 10);

12data.push({x: x, y: y});

13});

14// Shuffle the contents to ensure the model does not always train on the same

15// sequence of pitch data:

16tf.util.shuffle(data);

解析 CSV 数据后,需要将 JS 类型转换为 Tensor 批次进行培训和评估。有关此过程的详细信息,请参阅代码实验室。TensorFlow.js 团队正在开发一种新的 Data API,以便将来更容易获取。

注:代码实验室

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#batches

训练模型

让我们把这一切都整合在一起吧。定义了模型,准备好了训练数据,现在我们已经准备好开始训练了。以下异步方法训练一批训练样本并更新热图:

1// Trains and reports loss+accuracy for one batch of training data:

2async function trainBatch(index) {

3const history = await model.fit(batches[index].x, batches[index].y, {

4epochs: 1,

5shuffle: false,

6validationData: [batches[index].x, batches[index].y],

7batchSize: CONSTANTS.BATCH_SIZE

8});

9

10// Don't block the UI frame by using tf.nextFrame()

11await tf.nextFrame();

12updateHeatmap();

13await tf.nextFrame();

14}

可视化模型的准确性

使用来自均匀放置在本垒板上方的 4 x 4 英尺栅格的预测矩阵来构建热图。在每个训练步骤之后将该矩阵传递到模型中以检查模型的准确度。使用 D3 库将该预测的结果呈现为热图。

构建预测矩阵

热图中使用的预测矩阵从本垒板的中间开始,向左和向右各延伸 2 英尺。它的范围也从本垒板的底部到 4 英尺高。击打区样本位于本垒板上方 1.5 至 3.5 英尺之间。下图有助于让这些 2d 窗格可视化:

该视觉显示了打击区域和预测矩阵与本垒板和游戏区域相关的位置

将预测矩阵与模型一起使用

每个批次在模型中训练之后,预测矩阵被传递到模型中用以请求矩阵中的好球或坏球预测:

1function predictZone() {

2const predictions = model.predictOnBatch(predictionMatrix.data);

3const values = predictions.dataSync();

4

5// Sort each value so the higher prediction is the first element in the array:

6const results = [];

7let index = 0;

8for (let i = 0; i < values.length; i++) {    

9let list = [];

10list.push({value: values[index++], strike: 0});

11list.push({value: values[index++], strike: 1});

12list = list.sort((a, b) => b.value - a.value);

13results.push(list);

14}

15return results;

16}

热图与 D3

现在可以使用 D3 显示预测结果。 来自 50x50 网格中的每一个元素将在 SVG 中呈现为 10px x 10px 的矩形。每个矩形的颜色取决于预测结果(好球或者坏球)以及模型对该结果的确定程度(范围从 50%-100%)。 以下代码段显示了如何从 D3 svg 矩形分组更新数据:

1function updateHeatmap() {

2rects.data(generateHeatmapData());

3rects

4.attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })

5.attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })

6.attr('width', CONSTANTS.HEATMAP_SIZE)

7.attr('height', CONSTANTS.HEATMAP_SIZE)

8.style('fill', (coord) => {

9if (coord.strike) {

10return strikeColorScale(coord.value);

11} else {

12return ballColorScale(coord.value);

13}

14});

15}

有关使用 D3 绘制热图的完整详细信息,请参阅此部分。

注:此部分链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#colorDomain

总结

网络上有许多令人惊叹的第三方库和工具,可用于创建视觉效果。将这些与机器学习的强大功能与 TensorFlow.js 相结合,开发人员能够创建一些非常新奇有趣的演示。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4570

    浏览量

    98706
  • 机器学习
    +关注

    关注

    66

    文章

    8105

    浏览量

    130542
  • tensorflow
    +关注

    关注

    13

    文章

    313

    浏览量

    60242

原文标题:棒球比赛中是好球还是坏球?TensorFlow.js 已经知道

文章出处:【微信号:tensorflowers,微信公众号:Tensorflowers】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    可视化MES系统软件

    系统架构分析系统的功能模型使装配车间的生产控、质量控制和数据实时更新和历史统计都可视化展示,改进了生产过程中组装线和工艺流程的混乱、质量反馈问题的滞后和交互,
    发表于 11-30 19:55

    Pytorch模型训练实用PDF教程【中文】

    模型部分?还是优化器?只有这样不断的通过可视化诊断你的模型,不断的对症下药,才能训练出一个较满意的模型。本教程内容及结构:本教程内容主要为
    发表于 12-21 09:18

    Tensorflow之Tensorboard的可视化使用

    TF之Tensorboard:Tensorflow之Tensorboard可视化使用之详细攻略
    发表于 12-27 10:05

    基于Keras利用训练好的hdf5模型进行目标检测实现输出模型中的表情或性别gradcam

    CV:基于Keras利用训练好的hdf5模型进行目标检测实现输出模型中的脸部表情或性别的gradcam(可视化)
    发表于 12-27 16:48

    TensorFlow TensorBoard可视化数据流图

    间变化的:还可以使用 tf.summary.histogram 可视化梯度、权重或特定层的输出分布:摘要将在会话操作中生成。可以在计算图中定义 tf.merge_all_summaries OP 来
    发表于 07-22 21:26

    VR与三维可视化在电厂中的作用

    才会拥有的体验。智能培训模块采用虚拟现实技术,利用数字电厂的数据,建立相应的软硬件平台,利用定制的电厂三维模型,方便快捷地实现三维展示或者VR体验,提供安全的多用户虚拟世界,通过交互
    发表于 12-03 15:03

    数字可视化Web组态软件有哪些

    数字可视化Web组态软件有哪些?都有何优缺点?
    发表于 09-26 08:19

    Keras可视化神经网络架构的4种方法

    我们在使用卷积神经网络或递归神经网络或其他变体时,通常都希望对模型的架构可以进行可视化的查看,因为这样我们可以 在定义和训练多个模型时,比较不同的层以及它们放置的顺序对结果的影响。还有
    发表于 11-02 14:55

    keras可视化介绍

    keras可视化可以帮助我们直观的查看所搭建的模型拓扑结构,以及模型训练过程,方便我们优化模型
    发表于 08-18 07:53

    node.js训练好的神经网络模型识别图像中物体的方法

    如何在Node.js环境下使用训练好的神经网络模型(Inception、SSD)识别图像中的物体。
    的头像 发表于 04-06 13:11 8714次阅读

    TensorFlow发表推文正式发布TensorFlow v1.9

    其中有两个案例受到了大家的广泛关注,这个项目是通过 Colab 在 tf.keras 中训练模型,并通过TensorFlow.js 在浏览器中运行;最近在 JS 社区中,对这些相关项目
    的头像 发表于 07-16 10:23 2898次阅读

    如何使用TensorFlow.js构建这一系统

    TensorFlow.js团队一直在进行有趣的基于浏览器的实验,以使人们熟悉机器学习的概念,并鼓励他们将其用作您自己项目的构建块。对于那些不熟悉的人来说,TensorFlow.js是一个开源库,允许
    的头像 发表于 08-19 08:55 3336次阅读

    基于tensorflow.js设计、训练面向web的神经网络模型的经验

    了NVIDIA显卡。tensorflow.js在底层使用了WebGL加速,所以在浏览器中训练模型的一个好处是可以利用AMD显卡。另外,在浏览器中训练
    的头像 发表于 10-18 09:43 3866次阅读

    TensorFlow.js制作了一个仅用 200 余行代码的项目

    我们先来看一下运行的效果。下图中,上半部分是原始视频,下半部分是使用 TensorFlow.js 对人像进行消除后的视频。可以看到,除了偶尔会在边缘处留有残影之外,整体效果还是很不错的。
    的头像 发表于 05-11 18:08 5430次阅读

    如何基于 ES6 的 JavaScript 进行 TensorFlow.js 的开发

    从头开发、训练和部署模型,也可以用来运行已有的 Python 版 TensorFlow 模型,或者基于现有的模型进行继续
    的头像 发表于 10-31 11:16 2854次阅读