0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

如何将自定义图片输入到TensorFlow的训练模型

Dbwd_Imgtec 来源:未知 作者:李倩 2018-08-17 15:57 次阅读

很多正在入门或刚入门TensorFlow机器学习的同学希望能够通过自己指定图片源对模型进行训练,然后识别和分类自己指定的图片。但是,在TensorFlow官方入门教程中,并无明确给出如何把自定义数据输入训练模型的方法。现在,我们就参考官方入门课程《Deep MNIST for Experts》一节的内容(传送门:https://www.tensorflow.org/get_started/mnist/pros),介绍如何将自定义图片输入到TensorFlow的训练模型。

在《Deep MNISTfor Experts》一节的代码中,程序将TensorFlow自带的mnist图片数据集mnist.train.images作为训练输入,将mnist.test.images作为验证输入。当学习了该节内容后,我们会惊叹卷积神经网络的超高识别率,但对于刚开始学习TensorFlow的同学,内心可能会产生一个问号:如何将mnist数据集替换为自己指定的图片源?譬如,我要将图片源改为自己C盘里面的图片,应该怎么调整代码?

我们先看下该节课程中涉及到mnist图片调用的代码:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

batch = mnist.train.next_batch(50)

train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})

train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

对于刚接触TensorFlow的同学,要修改上述代码,可能会较为吃力。我也是经过一番摸索,才成功调用自己的图片集。

要实现输入自定义图片,需要自己先准备好一套图片集。为节省时间,我们把mnist的手写体数字集一张一张地解析出来,存放到自己的本地硬盘,保存为bmp格式,然后再把本地硬盘的手写体图片一张一张地读取出来,组成集合,再输入神经网络。mnist手写体数字集的提取方式详见《如何从TensorFlow的mnist数据集导出手写体数字图片》。

将mnist手写体数字集导出图片到本地后,就可以仿照以下python代码,实现自定义图片的训练:

#!/usr/bin/python3.5

# -*- coding: utf-8 -*-

import os

import numpy as np

import tensorflow as tf

from PIL import Image

# 第一次遍历图片目录是为了获取图片总数

input_count = 0

for i in range(0,10):

dir = './custom_images/%s/' % i # 这里可以改成你自己的图片目录,i为分类标签

for rt, dirs, files in os.walk(dir):

for filename in files:

input_count += 1

# 定义对应维数和各维长度的数组

input_images = np.array([[0]*784 for i in range(input_count)])

input_labels = np.array([[0]*10 for i in range(input_count)])

# 第二次遍历图片目录是为了生成图片数据和标签

index = 0

for i in range(0,10):

dir = './custom_images/%s/' % i # 这里可以改成你自己的图片目录,i为分类标签

for rt, dirs, files in os.walk(dir):

for filename in files:

filename = dir + filename

img = Image.open(filename)

width = img.size[0]

height = img.size[1]

for h in range(0, height):

for w in range(0, width):

# 通过这样的处理,使数字的线条变细,有利于提高识别准确率

if img.getpixel((w, h)) > 230:

input_images[index][w+h*width] = 0

else:

input_images[index][w+h*width] = 1

input_labels[index][i] = 1

index += 1

# 定义输入节点,对应于图片像素值矩阵集合和图片标签(即所代表的数字)

x = tf.placeholder(tf.float32, shape=[None, 784])

y_ = tf.placeholder(tf.float32, shape=[None, 10])

x_image = tf.reshape(x, [-1, 28, 28, 1])

# 定义第一个卷积层的variables和ops

W_conv1 = tf.Variable(tf.truncated_normal([7, 7, 1, 32], stddev=0.1))

b_conv1 = tf.Variable(tf.constant(0.1, shape=[32]))

L1_conv = tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')

L1_relu = tf.nn.relu(L1_conv + b_conv1)

L1_pool = tf.nn.max_pool(L1_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# 定义第二个卷积层的variables和ops

W_conv2 = tf.Variable(tf.truncated_normal([3, 3, 32, 64], stddev=0.1))

b_conv2 = tf.Variable(tf.constant(0.1, shape=[64]))

L2_conv = tf.nn.conv2d(L1_pool, W_conv2, strides=[1, 1, 1, 1], padding='SAME')

L2_relu = tf.nn.relu(L2_conv + b_conv2)

L2_pool = tf.nn.max_pool(L2_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# 全连接层

W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))

b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))

h_pool2_flat = tf.reshape(L2_pool, [-1, 7*7*64])

h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# dropout

keep_prob = tf.placeholder(tf.float32)

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# readout层

W_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))

b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

# 定义优化器和训练op

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))

train_step = tf.train.AdamOptimizer((1e-4)).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

print ("一共读取了 %s 个输入图像, %s 个标签" % (input_count, input_count))

# 设置每次训练op的输入个数和迭代次数,这里为了支持任意图片总数,定义了一个余数remainder,譬如,如果每次训练op的输入个数为60,图片总数为150张,则前面两次各输入60张,最后一次输入30张(余数30)

batch_size = 60

iterations = 100

batches_count = int(input_count / batch_size)

remainder = input_count % batch_size

print ("数据集分成 %s 批, 前面每批 %s 个数据,最后一批 %s 个数据" % (batches_count+1, batch_size, remainder))

# 执行训练迭代

for it in range(iterations):

# 这里的关键是要把输入数组转为np.array

for n in range(batches_count):

train_step.run(feed_dict={x: input_images[n*batch_size:(n+1)*batch_size], y_: input_labels[n*batch_size:(n+1)*batch_size], keep_prob: 0.5})

if remainder > 0:

start_index = batches_count * batch_size;

train_step.run(feed_dict={x: input_images[start_index:input_count-1], y_: input_labels[start_index:input_count-1], keep_prob: 0.5})

# 每完成五次迭代,判断准确度是否已达到100%,达到则退出迭代循环

iterate_accuracy = 0

if it%5 == 0:

iterate_accuracy = accuracy.eval(feed_dict={x: input_images, y_: input_labels, keep_prob: 1.0})

print ('iteration %d: accuracy %s' % (it, iterate_accuracy))

if iterate_accuracy >= 1:

break;

print ('完成训练!')

上述python代码的执行结果截图如下:

对于上述代码中与模型构建相关的代码,请查阅官方《Deep MNIST for Experts》一节的内容进行理解。在本文中,需要重点掌握的是如何将本地图片源整合成为feed_dict可接受的格式。其中最关键的是这两行:

# 定义对应维数和各维长度的数组

input_images = np.array([[0]*784 for i in range(input_count)])

input_labels = np.array([[0]*10 for i in range(input_count)])

它们对应于feed_dict的两个placeholder:

x = tf.placeholder(tf.float32, shape=[None, 784])

y_ = tf.placeholder(tf.float32, shape=[None, 10])

这样一看,是不是很简单?

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4538

    浏览量

    98425
  • 数据集
    +关注

    关注

    4

    文章

    1174

    浏览量

    24285
  • tensorflow
    +关注

    关注

    13

    文章

    313

    浏览量

    60232

原文标题:如何用TensorFlow训练和识别/分类自定义图片

文章出处:【微信号:Imgtec,微信公众号:Imagination Tech】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    基于YOLOv8实现自定义姿态评估模型训练

    Hello大家好,今天给大家分享一下如何基于YOLOv8姿态评估模型,实现在自定义数据集上,完成自定义姿态评估模型训练与推理。
    的头像 发表于 12-25 11:29 1019次阅读
    基于YOLOv8实现<b class='flag-5'>自定义</b>姿态评估<b class='flag-5'>模型</b><b class='flag-5'>训练</b>

    自定义图片按键

    图片按键设置是在自定义控件里面设置,这我是知道的,但我需要如图所示效果:按键正常情况下如图1所示,当鼠标放在按键上时(鼠标不点击),按键就出现图2所示现象,就像菜单图标一样。请各位大神指教,这应该怎么操作啊?谢谢! 图1图2
    发表于 12-27 15:33

    LabVIEW自定义控件

    表示程序内容或者按键功 能的图 片导入 控件 中。我们具体来看一下这样 的程序 是如何 实现的 。这个程序中主要使用了3 个自定义控件:火球,小怪物还有蘑菇按钮,其中的火球是通过迚度条控件实现
    发表于 01-07 10:57

    labview如何在自定义里修改仪表控件的指针?

    本帖最后由 weihk 于 2017-2-23 19:46 编辑 请问,如何将labview中表盘控件的指针修改成自定义形状?例如,用真实仪表盘照片中提取的指针来代替控件中原有指针。 在自定义
    发表于 02-23 18:50

    使用Yocto创建了fsl-image-gui,如何将自定义脚本添加到RootFS?

    您好,我使用 Yocto 创建了 fsl-image-gui,我想将自定义脚本添加到 RootFS。当我使用命令重建图像时,我确实尝试将它添加到“tmp/work
    发表于 03-27 08:32

    1602自定义字符

    1602液晶能够显示自定义字符,能够根据读者的具体情况显示自定义字符。
    发表于 01-20 15:43 1次下载

    RTWconfigurationguide基于模型设计—自定义

    基于模型设计—自定义目标系统配置指南,RTW自动代码生成相关资料。
    发表于 05-17 16:41 3次下载

    16.stm32f10显示自定义图片

    显示自定义图片
    发表于 10-20 16:57 2次下载

    如何在TensorFlow2里使用Keras API创建一个自定义CNN网络?

    http://yann.lecun.com/exdb/mnist/。 在该例程中我们会演示以下的步骤: 使用 TensorFlow2 训练和评估小型自定义卷积神经网络 对浮点模型进行量
    的头像 发表于 04-15 11:36 1879次阅读

    将自定义热敏电阻与Temp-to-Bits系列配合使用

    将自定义热敏电阻与Temp-to-Bits系列配合使用
    发表于 04-18 14:27 7次下载
    <b class='flag-5'>将自定义</b>热敏电阻与Temp-to-Bits系列配合使用

    如何在移动设备上训练和部署自定义目标检测模型

    上,目标检测模型训练和部署的过程: 设备端 ML 学习路径:关于如何在移动设备上,训练和部署自定义目标检测模型的分步教程,无需机器学习专业
    的头像 发表于 08-16 17:09 2884次阅读

    OpenHarmony自定义组件FlowImageLayout

    组件介绍 本示例是OpenHarmony自定义组件FlowImageLayout。 用于将一个图片列表以瀑布流的形式显示出来。 调用方法
    发表于 03-21 10:17 3次下载
    OpenHarmony<b class='flag-5'>自定义</b>组件FlowImageLayout

    自定义视图组件教程案例

    自定义组件 1.自定义组件-particles(粒子效果) 2.自定义组件- pulse(脉冲button效果) 3.自定义组件-progress(progress效果) 4.
    发表于 04-08 10:48 14次下载

    自定义算子开发

    一个完整的自定义算子应用过程包括注册算子、算子实现、含自定义算子模型转换和运行含自定义op模型四个阶段。在大多数情况下,您的
    的头像 发表于 04-07 16:11 1834次阅读
    <b class='flag-5'>自定义</b>算子开发

    labview超快自定义控件制作和普通自定义控件制作

    labview超快自定义控件制作和普通自定义控件制作
    发表于 08-21 10:32 5次下载