0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

使用TensorFlow决策森林创建提升树模型

谷歌开发者 来源:TensorFlow 作者:TensorFlow 2022-04-19 10:46 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

发布人:TensorFlow 团队的 Mathieu Guillame-Bert 和 Josh Gordon

随机森林和梯度提升树这类的决策森林模型通常是处理表格数据最有效的可用工具。与神经网络相比,决策森林具有更多优势,如配置过程更轻松、训练速度更快等。使用树可大幅减少准备数据集所需的代码量,因为这些树本身就可以处理数字、分类和缺失的特征。此外,这些树通常还可提供开箱即用的良好结果,并具有可解释的属性。

尽管我们通常将 TensorFlow 视为训练神经网络的内容库,但 Google 的一个常见用例是使用 TensorFlow 创建决策森林。

08660ec6-bf00-11ec-9e50-dac502259ad0.gif

对数据开展分类的决策树动画

如果您曾使用 2019 年推出tf.estimator.BoostedTrees 创建基于树的模型,您可参考本文所提供的指南进行迁移。虽然 Estimator API 基本可以应对在生产环境中使用模型的复杂性,包括分布式训练和序列化,但是我们不建议您将其用于新代码。

如果您要开始一个新项目,我们建议您使用 TensorFlow 决策森林 (TF-DF)。该内容库可为训练、服务和解读决策森林模型提供最先进的算法,相较于先前的方法更具优势,特别是在质量、速度和易用性方面表现尤为出色。

首先,让我们来比较一下使用 Estimator API 和 TF-DF 创建提升树模型的等效示例。

以下是使用 tf.estimator.BoostedTrees 训练梯度提升树模型的旧方法(不再推荐使用)

import tensorflow as tf

# Dataset generators
def make_dataset_fn(dataset_path):
    def make_dataset():
        data = ... # read dataset
        return tf.data.Dataset.from_tensor_slices(...data...).repeat(10).batch(64)
    return make_dataset

# List the possible values for the feature "f_2".
f_2_dictionary = ["NA", "red", "blue", "green"]

# The feature columns define the input features of the model.
feature_columns = [
    tf.feature_column.numeric_column("f_1"),
    tf.feature_column.indicator_column(
       tf.feature_column.categorical_column_with_vocabulary_list("f_2",
         f_2_dictionary,
         # A special value "missing" is used to represent missing values.
         default_value=0)
       ),
    ]

# Configure the estimator
estimator = boosted_trees.BoostedTreesClassifier(
          n_trees=1000,
          feature_columns=feature_columns,
          n_classes=3,
          # Rule of thumb proposed in the BoostedTreesClassifier documentation.
          n_batches_per_layer=max(2, int(len(train_df) / 2 / FLAGS.batch_size)),
      )

# Stop the training is the validation loss stop decreasing.
early_stopping_hook = early_stopping.stop_if_no_decrease_hook(
      estimator,
      metric_name="loss",
      max_steps_without_decrease=100,
      min_steps=50)

tf.estimator.train_and_evaluate(
      estimator,
      train_spec=tf.estimator.TrainSpec(
          make_dataset_fn(train_path),
          hooks=[
              # Early stopping needs a CheckpointSaverHook.
              tf.train.CheckpointSaverHook(
                  checkpoint_dir=input_config.raw.temp_dir, save_steps=500),
              early_stopping_hook,
          ]),
      eval_spec=tf.estimator.EvalSpec(make_dataset_fn(valid_path)))

使用 TensorFlow 决策森林训练相同的模型

import tensorflow_decision_forests as tfdf

# Load the datasets
# This code is similar to the estimator.
def make_dataset(dataset_path):
    data = ... # read dataset
    return tf.data.Dataset.from_tensor_slices(...data...).batch(64)

train_dataset = make_dataset(train_path)
valid_dataset = make_dataset(valid_path)

# List the input features of the model.
features = [
  tfdf.keras.FeatureUsage("f_1", keras.FeatureSemantic.NUMERICAL),
  tfdf.keras.FeatureUsage("f_2", keras.FeatureSemantic.CATEGORICAL),
]

model = tfdf.keras.GradientBoostedTreesModel(
  task = tfdf.keras.Task.CLASSIFICATION,
  num_trees=1000,
  features=features,
  exclude_non_specified_features=True)

model.fit(train_dataset, valid_dataset)

# Export the model to a SavedModel.
model.save("project/model")

附注

  • 虽然在此示例中没有明确说明,但 TensorFlow 决策森林可自动启用和配置早停

  • 可自动构建和优化“f_2”特征字典(例如,将稀有值合并到一个未登录词项目中)。

  • 可从数据集中自动确定类别数(本例中为 3 个)。

  • 批次大小(本例中为 64)对模型训练没有影响。以较大值为宜,因为这可以增加读取数据集的效率。

TF-DF 的亮点就在于简单易用,我们还可进一步简化和完善上述示例,如下所示。

如何训练 TensorFlow 决策森林(推荐解决方案)

import tensorflow_decision_forests as tfdf
import pandas as pd

# Pandas dataset can be used easily with pd_dataframe_to_tf_dataset.
train_df = pd.read_csv("project/train.csv")

# Convert the Pandas dataframe into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")

model = tfdf.keras.GradientBoostedTreeModel(num_trees=1000)
model.fit(train_dataset)

附注

  • 我们未指定特征的语义(例如数字或分类)。在这种情况下,系统将自动推断语义。

  • 我们也没有列出要使用的输入特征。在这种情况下,系统将使用所有列(标签除外)。可在训练日志中查看输入特征的列表和语义,或通过模型检查器 API 查看。

  • 我们没有指定任何验证数据集。每个算法都可以从训练样本中提取一个验证数据集作为算法的最佳选择。例如,默认情况下,如果未提供验证数据集,则 GradientBoostedTreeModel 将使用 10% 的训练数据进行验证。

下面我们将介绍 Estimator API 和 TF-DF 的一些区别。

Estimator API 和 TF-DF 的区别

算法类型

TF-DF 是决策森林算法的集合,包括(但不限于)Estimator API 提供的梯度提升树。请注意,TF-DF 还支持随机森林(非常适用于干扰数据集)和 CART 实现(非常适用于解读模型)。

此外,对于每个算法,TF-DF 都包含许多在文献资料中发现并经过实验验证的变体 [1, 2, 3]。

精确与近似分块的对比

TF1 GBT Estimator 是一种近似的树学习算法。非正式情况下,Estimator 通过仅考虑样本的随机子集和每个步骤条件的随机子集来构建

默认情况下,TF-DF 是一种精确的树训练算法。非正式情况下,TF-DF 会考虑所有训练样本和每个步骤的所有可能分块。这是一种更常见且通常表现更佳的解决方案。

虽然对于较大的数据集(具有百亿数量级以上的“样本和特征”数组)而言,有时 Estimator 的速度更快,但其近似值通常不太准确(因为需要种植更多树才能达到相同的质量)。而对于小型数据集(所含的“样本和特征”数组数目不足一亿)而言,使用 Estimator 实现近似训练形式的速度甚至可能比精确训练更慢。

TF-DF 还支持不同类型的“近似”树训练。我们建议您使用精确训练法,并选择使用大型数据集测试近似训练。

推理

Estimator 使用自上而下的树路由算法运行模型推理。TF-DF 使用 QuickScorer 算法的扩展程序。

虽然两种算法返回的结果完全相同,但自上而下的算法效率较低,因为这种算法的计算量会超出分支预测并导致缓存未命中。对于同一模型,TF-DF 的推理速度通常可提升 10 倍。

TF-DF 可为延迟关键应用程序提供 C++ API。其推理时间约为每核心每样本 1 微秒。与 TF SavedModel 推理相比,这通常可将速度提升 50 至 1000 倍(对小型批次的效果更佳)。

多头模型

Estimator 支持多头模型(即输出多种预测的模型)。目前,TF-DF 无法直接支持多头模型,但是借助 Keras Functional API,TF-DF 可以将多个并行训练的 TF-DF 模型组成一个多头模型。

了解详情

您可以访问此网址,详细了解 TensorFlow 决策森林。

如果您是首次接触该内容库,我们建议您从初学者示例开始。经验丰富的 TensorFlow 用户可以访问此指南,详细了解有关在 TensorFlow 中使用决策森林和神经网络的区别要点,包括如何配置训练流水线和关于数据集 I/O 的提示。

您还可以仔细阅读Estimator 迁移到 Keras API,了解如何从 Estimator 迁移到 Keras。

原文标题:如何从提升树 Estimator 迁移到 TensorFlow 决策森林

文章出处:【微信公众号:谷歌开发者】欢迎添加关注!文章转载请注明出处。

审核编辑:汤梓红
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • Google
    +关注

    关注

    5

    文章

    1801

    浏览量

    60247
  • 模型
    +关注

    关注

    1

    文章

    3648

    浏览量

    51692
  • tensorflow
    +关注

    关注

    13

    文章

    331

    浏览量

    61841

原文标题:如何从提升树 Estimator 迁移到 TensorFlow 决策森林

文章出处:【微信号:Google_Developers,微信公众号:谷歌开发者】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    构建生命线:云翎智能应急通信自组网如何为森林防火赢得黄金救援时间

    森林火灾的黄金救援时间以分钟计,传统通信常因地形复杂、公网薄弱导致失联。云翎智能应急通信自组网以无中心自组网技术实现3分钟快速建网,通过宽窄带融合传输火场动态,结合北斗短报文补盲,让指挥中心秒级决策
    的头像 发表于 07-14 22:04 321次阅读
    构建生命线:云翎智能应急通信自组网如何为<b class='flag-5'>森林</b>防火赢得黄金救援时间

    无法将Tensorflow Lite模型转换为OpenVINO™格式怎么处理?

    Tensorflow Lite 模型转换为 OpenVINO™ 格式。 遇到的错误: FrontEnd API failed with OpConversionFailure:No translator found for TFLite_Detection_PostP
    发表于 06-25 08:27

    用树莓派搞深度学习?TensorFlow启动!

    RaspberryPi4上运行TensorFlow,但不要期望有奇迹般的表现。如果模型不太复杂,它可以运行您的模型,但无法训练新模型,也无法执行所谓的迁移学习。除了运行您预
    的头像 发表于 03-25 09:33 958次阅读
    用树莓派搞深度学习?<b class='flag-5'>TensorFlow</b>启动!

    TensorFlow模型转换为中间表示 (IR) 时遇到不一致的形状错误怎么解决?

    使用命令转换为 Tensorflow* 模型: mo --input_model ../models/middlebury_d400.pb --input_shape [1,352,704,6
    发表于 03-07 08:20

    使用OpenVINO™ 2020.4.582将自定义TensorFlow 2模型转换为中间表示 (IR)收到错误怎么解决?

    转换自定义 TensorFlow 2 模型 mask_rcnn_inception_resnet_v2_1024x1024_coco17 要 IR 使用模型优化器命令: 注意上面的链接可能无法
    发表于 03-07 07:28

    将YOLOv4模型转换为IR的说明,无法将模型转换为TensorFlow2格式怎么解决?

    遵照 将 YOLOv4 模型转换为 IR 的 说明,但无法将模型转换为 TensorFlow2* 格式。 将 YOLOv4 darknet 转换为 Keras 模型时,收到 Type
    发表于 03-07 07:14

    Tensorflow Efficientdet-d0模型转换为OpenVINO™ IR失败了,怎么解决?

    使用转换命令 mo --saved_model_dir /home/obs-56/effi/saved_model 将 TensorFlow* efficientdet-d0 模型转换为 IR
    发表于 03-06 08:18

    可以使用OpenVINO™工具包将中间表示 (IR) 模型转换为TensorFlow格式吗?

    无法将中间表示 (IR) 模型转换为 TensorFlow* 格式
    发表于 03-06 06:51

    使用各种TensorFlow模型运行模型优化器时遇到错误非法指令怎么解决?

    使用各种 TensorFlow 模型运行模型优化器时遇到 [i]错误非法指令
    发表于 03-05 09:56

    为什么无法使用OpenVINO™模型优化器转换TensorFlow 2.4模型

    已下载 ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8 型号。 使用将模型转换为中间表示 (IR) ssd_support_api_v.2.4.json
    发表于 03-05 09:07

    为什么无法将TensorFlow自定义模型转换为IR格式?

    TensorFlow* 自定义模型转换为 IR 格式: mo --data_type FP16 --saved_model_dir--input_shape (1,150,150,3
    发表于 03-05 07:26

    为什么无法将自定义EfficientDet模型TensorFlow 2转换为中间表示(IR)?

    将自定义 EfficientDet 模型TensorFlow* 2 转换 为 IR 时遇到错误: [ ERROR ] Exception occurred during running replacer \"REPLACEMENT_ID\" ()
    发表于 03-05 06:29

    科技在物联网方面

    给其他设备或云端进行分析和处理。 与通信企业合作:宇科技可能与通信企业展开合作,共同探索5G、6G等新一代通信技术在机器人领域的应用,以提升机器人的通信效率和稳定性,满足物联网场景下大量设备连接和数
    发表于 02-04 06:48

    用Reality AI Tools创建模型

    在第二步采集到的数据基础之上,用Reality AI Tools创建模型
    的头像 发表于 01-22 14:23 2877次阅读
    用Reality AI Tools<b class='flag-5'>创建模型</b>

    外资制造业可利用AI提升决策能力

    运筹优化技术是一种利用数学模型和算法,在有限资源下寻求最佳决策的技术,广泛应用于物流、生产、金融等领域。运筹优化能够帮助解决复杂的优化问题,例如资源分配、路径规划、生产调度等,以提高效率、降低成本或
    的头像 发表于 12-24 10:01 709次阅读