0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

如何在训练过程中正确地把数据输入给模型

XILINX开发者社区 来源:XILINX开发者社区 作者:XILINX开发者社区 2021-07-01 10:47 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

机器学习中一个常见问题是判定与数据交互的最佳方式。

在本文中,我们将提供一种高效方法,用于完成数据的交互、组织以及最终变换(预处理)。随后,我们将讲解如何在训练过程中正确地把数据输入给模型。

PyTorch 框架将帮助我们实现此目标,我们还将从头开始编写几个类。PyTorch 可提供更完整的原生类,但创建我们自己的类可帮助我们加速学习。

第 1 部分:原始数据和数据集

首先我们把尚未经过组织的所有样本称为“原始数据”。

把“数据集”定义为现成可用的数据,即含标签以及基本函数接口(以便于使用原始数据信息)的原始数据。

此处我们使用一种简单的原始数据形式:1 个包含图像和标签的文件夹。

但此方法可扩展至任意性质的样本(可以是图片、录音、视频等)以及包含标签的文件。

标签文件中的每一行都用于描述 1 个样本和相关标签,格式如下:

file_sample_1 label1

file_sample_2 label2

file_sample_3 label3

(。。。)

当能够完成一些基本信息查询(已有样本数量、返回特定编号的样本、预处理每个样本等)时,说明我们已从原始数据集创建了 1 个数据集。

此方法基于面向对象编程以及创建用于数据处理的 “类”。

对于一组简单的图像和标签而言,此方法可能看上去略显杀鸡用牛刀(实际上,此用例通常是通过创建分别用于训练、验证和测试的独立文件夹来进行处理的)。但如果要选择标准交互方法,则此方法将来可复用于多种不同用例,以节省时间。

Python 中处理数据

在 Python 中所有一切都是对象:整数、列表、字典都是如此。

构建含标准属性和方法的“数据集”对象的原因多种多样。我认为,代码的精致要求就足以合理化这一选择,但我理解这是品味的问题。可移植性、速度和代码模块化可能是最重要的原因。

在许多示例以及编码书籍中,我发现了面向对象的编码(尤以类为甚)的其它有趣的功能和优势,总结如下:

• 类可提供继承

• 继承可提供复用

• 继承可提供数据类型扩展

• 继承支持多态现象

• 继承是面向对象的编码的特有功能

■输入 [1]:

import torch

from torchvision import transforms

to_tensor = transforms.ToTensor()

from collections import namedtuple

import functools

import copy

import csv

from PIL import Image

from matplotlib import pyplot as plt

import numpy as np

import os

import datetime

import torch.optim as optim

在我们的示例中,所有原始样本都存储在文件夹中。此文件夹的地址在 raw_data_path 变量中声明。

■输入 [2]:

raw_data_path = ‘。/raw_data/data_images’

构建模块

数据集接口需要一些函数和类。数据集本身就是一个对象,因此我们将创建 MyDataset 类来包含所有重要函数和变量。

首先,我们需要读取标签文件,然后可对样本在其原始格式(此处为 PIL 图像)以及最终的张量格式应用某些变换。

我们需要使用以下函数来读取 1 次标签文件,然后创建包含所有样本名称和标签的元组。

内存中缓存可提升性能,但如果标签文件发生更改,请务必更新缓存内容。

■ 输入 [113]:

DataInfoTuple = namedtuple(‘Sample’,‘SampleName, SampleLabel’)

def myFunc(e):

return e.SampleLabel

# in memory caching decorator: ref https://dbader.org/blog/python-memoization

@functools.lru_cache(1)

def getSampleInfoList(raw_data_path):

sample_list = []

with open(str(raw_data_path) + ‘/labels.txt’, mode = ‘r’) as f:

reader = csv.reader(f, delimiter = ‘ ’)

for i, row in enumerate(reader):

imgname = row[0]

label = int(row[1])

sample_list.append(DataInfoTuple(imgname, label))

sample_list.sort(reverse=False, key=myFunc)

# print(“DataInfoTouple: samples list length = {}”.format(len(sample_list)))

return sample_list

如需直接变换 PIL 图像,那么以下类很实用。

该类仅含 1 种方法:resize。resize 方法能够改变 PIL 图像的原始大小,并对其进行重新采样。如需其它预处理(翻转、剪切、旋转等),需在此类种添加方法。

当 PIL 图像完成预处理后,即可将其转换为张量。此外还可对张量执行进一步的处理步骤。

在以下示例种,可以看到这两种变换:

■ 输入 [4]:

class PilTransform():

“”“generic transformation of a pil image”“”

def resize(self, img, **kwargs):

img = img.resize(( kwargs.get(‘width’), kwargs.get(‘height’)), resample=Image.NEAREST)

return img

# creation of the object pil_transform, having all powers inherited by the class PilTransform

pil_transform = PilTransform()

以下是类 PilTransform 的实操示例:

■ 输入 [5]:

path = raw_data_path + “/img_00000600.JPEG”

print(path)

im1 = Image.open(path, mode=‘r’)

plt.imshow(im1)

。/raw_data/data_images/img_00000600.JPEG

■ 输出 [5]:

《matplotlib.image.AxesImage at 0x121046f5588》

■ 输入 [6]:

im2 = pil_transform.resize(im1, width=128, height=128)

# im2.show()

plt.imshow(im2)

■ 输出 [6]:

《matplotlib.image.AxesImage at 0x12104b36358》

最后,我们定义一个类,用于实现与原始数据的交互。

类 MyDataset 主要提供了 2 个方法:

__len__ 可提供原始样本的数量。

__getitem__ 可使对象变为可迭代类型,并按张量格式返回请求的样本(已完成预处理)。

__getitem__ 步骤:

1) 打开来自文件的样本。

2) 按样本的原始格式对其进行预处理。

3) 将样本变换为张量。

4) 以张量格式对样本进行预处理。

此处添加的预处理仅作为示例。

此类可对张量进行归一化(求平均值和标准差),这有助于加速训练过程。

请注意,PIL 图像由范围 0-255 内的整数值组成,而张量则为范围 0-1 内的浮点数矩阵。

该类会返回包含两个元素的列表:在位置 [0] 返回张量,在位置 [1] 返回包含 SampleName 和 SampleLabel 的命名元组。

■ 输入 [109]:

class MyDataset():

“”“Interface class to raw data, providing the total number of samples in the dataset and a preprocessed item”“”

def __init__(self,

isValSet_bool = None,

raw_data_path = ‘。/’,

SampleInfoList = DataInfoTuple,norm = False,

resize = False,

newsize = (32, 32)

):

self.raw_data_path = raw_data_path

self.SampleInfoList = copy.copy(getSampleInfoList(self.raw_data_path))

self.isValSet_bool = isValSet_bool

self.norm = norm

self.resize = resize

self.newsize = newsize

def __str__(self):

return ‘Path of raw data is ’ + self.raw_data_path + ‘/’ + ‘《raw samples》’

def __len__(self):

return len(self.SampleInfoList)

def __getitem__(self, ndx):

SampleInfoList_tup = self.SampleInfoList[ndx]

filepath = self.raw_data_path + ‘/’ + str(SampleInfoList_tup.SampleName)

if os.path.exists(filepath):

img = Image.open(filepath)

# PIL image preprocess (examples)

#resize

if self.resize:

width, height = img.size

if (width 》= height) & (self.newsize[0] 》= self.newsize[1]):

img = pil_transform.resize(img, width=self.newsize[0], height=self.newsize[1])

elif (width 》= height) & (self.newsize[0] 《 self.newsize[1]):

img = pil_transform.resize(img, width=self.newsize[1], height=self.newsize[0])

elif (width 《 height) & (self.newsize[0] 《= self.newsize[1]):

img = pil_transform.resize(img, width=self.newsize[0], height=self.newsize[1])

elif (width 《 height) & (self.newsize[0] 》 self.newsize[1]):

img = pil_transform.resize(img, width=self.newsize[1], height=self.newsize[0])

else:

print(“ERROR”)

# from pil image to tensor

img_t = to_tensor(img)

# tensor preprocess (examples)

#rotation

ratio = img_t.shape[1]/img_t.shape[2]

if ratio 》 1:

img_t = torch.rot90(img_t, 1, [1, 2])

#normalization requires the knowledge of all tensors

if self.norm:

img_t = normalize(img_t)

#return img_t, SampleInfoList_tup

return img_t, SampleInfoList_tup.SampleLabel

else:

print(‘[WARNING] file {} does not exist’.format(str(SampleInfoList_tup.SampleName)))

return None

编辑:jq

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 机器学习
    +关注

    关注

    67

    文章

    8565

    浏览量

    137226
  • 数据集
    +关注

    关注

    4

    文章

    1240

    浏览量

    26261
  • PIL
    PIL
    +关注

    关注

    0

    文章

    19

    浏览量

    8982
  • pytorch
    +关注

    关注

    2

    文章

    813

    浏览量

    14923

原文标题:开发者分享 | 利用 Python 和 PyTorch 处理面向对象的数据集:1. 原始数据和数据集

文章出处:【微信号:gh_2d1c7e2d540e,微信公众号:XILINX开发者社区】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    Edge Impulse 唤醒词模型训练 | 技术集结

    Edgi-Talk开始使用边缘机器学习!目录EdgeImpulse简介创建账号录制数据数据上传数据分割模型训练
    的头像 发表于 04-20 10:05 640次阅读
    Edge Impulse 唤醒词<b class='flag-5'>模型</b><b class='flag-5'>训练</b> | 技术集结

    AI大模型微调企业项目实战课

    成长为该领域的资深专家。通过将企业积累的高质量业务问答对、专业文档输入模型,调整其内部的极小部分参数,就能让模型在保持原有通用能力的基础上,精准掌握企业的特定语感和输出规范。这不仅将
    发表于 04-16 18:48

    【瑞萨AI挑战赛】手写数字识别模型在RA8P1 Titan Board上的部署

    设置为0.02; 损失函数:稀疏交叉熵损失(适配MNIST数据集的整数标签); 评价指标:准确率(Accuracy); 训练轮数(Epoch):10 训练过程中将打印每一轮的准确率情况,同时会将每一轮
    发表于 03-15 20:42

    自动驾驶大模型训练数据有什么具体要求?

    [首发于智驾最前沿微信公众号]想训练出一个可以落地的自动驾驶大模型,不是简单地其提供几张图片,几条规则就可以的,而是需要非常多的多样的、真实的驾驶数据,从而可以让大
    的头像 发表于 12-26 09:32 364次阅读
    自动驾驶大<b class='flag-5'>模型</b>的<b class='flag-5'>训练</b><b class='flag-5'>数据</b>有什么具体要求?

    在Ubuntu20.04系统中训练神经网络模型的一些经验

    , batch_size=512, epochs=20)总结 这个核心算法中的卷积神经网络结构和训练过程,是用来对MNIST手写数字图像进行分类的。模型将图像作为输入,通过卷积和池化层提取图像的特征,然后通过全连接层进行分类预
    发表于 10-22 07:03

    何在vivadoHLS中使用.TLite模型

    本帖欲分享如何在vivadoHLS中使用.TLite模型。在Vivado HLS中导入模型后,需要设置其输入和输出接口以与您的设计进行适配。 1. 在Vivado HLS项目中导入
    发表于 10-22 06:29

    借助NVIDIA Megatron-Core大模型训练框架提高显存使用效率

    随着模型规模迈入百亿、千亿甚至万亿参数级别,如何在有限显存中“塞下”训练任务,对研发和运维团队都是巨大挑战。NVIDIA Megatron-Core 作为流行的大模型
    的头像 发表于 10-21 10:55 1401次阅读
    借助NVIDIA Megatron-Core大<b class='flag-5'>模型</b><b class='flag-5'>训练</b>框架提高显存使用效率

    何在FPGA部署AI模型

    如果你已经在用 MATLAB 做深度学习,那一定知道它的训练和仿真体验非常丝滑。但当模型要真正落地到 FPGA 上时,往往就会卡住:怎么网络结构和权重优雅地搬到硬件里?
    的头像 发表于 09-24 10:00 4637次阅读
    如<b class='flag-5'>何在</b>FPGA部署AI<b class='flag-5'>模型</b>

    何在Ray分布式计算框架下集成NVIDIA Nsight Systems进行GPU性能分析

    在大语言模型的强化学习训练过程中,GPU 性能优化至关重要。随着模型规模不断扩大,如何高效地分析和优化 GPU 性能成为开发者面临的主要挑战之一。
    的头像 发表于 07-23 10:34 2567次阅读
    如<b class='flag-5'>何在</b>Ray分布式计算框架下集成NVIDIA Nsight Systems进行GPU性能分析

    SC11 FP300 MLA算子融合与优化

    ,降低推理过程中需要的KVcache,提升推理效率。MLA对attention过程中的Query也进行了低秩压缩,可以减少训练过程中激活的内存。大模型的推理分为两阶
    的头像 发表于 06-27 14:32 1734次阅读
    SC11 FP300 MLA算子融合与优化

    群晖发布AI模型全流程存储解决方案,破局训练效率与数据孤岛难题

    IDC预测:从2023年每秒产生4.2PB数据,到2028年将激增至12.5PB——AI大模型掀起的数据海啸已席卷而来。企业争相投入千亿参数模型训练
    的头像 发表于 06-25 16:03 835次阅读
    群晖发布AI<b class='flag-5'>模型</b>全流程存储解决方案,破局<b class='flag-5'>训练</b>效率与<b class='flag-5'>数据</b>孤岛难题

    超声波清洗机如何在清洗过程中减少废液和对环境的影响?

    超声波清洗机如何在清洗过程中减少废液和对环境的影响随着环保意识的增强,清洗过程中的废液处理和环境保护变得越来越重要。超声波清洗机作为一种高效的清洗技术,也在不断发展以减少废液生成和对环境的影响。本文
    的头像 发表于 06-16 17:01 818次阅读
    超声波清洗机如<b class='flag-5'>何在</b>清洗<b class='flag-5'>过程中</b>减少废液和对环境的影响?

    瑞芯微模型量化文件构建

    模型是一张图片输入时,量化文件如上图所示。但是我现在想量化deepprivacy人脸匿名模型,他的输入是四个输入。该
    发表于 06-13 09:07

    算力网络的“神经突触”:AI互联技术如何重构分布式训练范式

      电子发烧友网综合报道 随着AI技术迅猛发展,尤其是大型语言模型的兴起,对于算力的需求呈现出爆炸性增长。这不仅推动了智算中心的建设,还对网络互联技术提出了新的挑战。   在AI大模型训练过程中
    的头像 发表于 06-08 08:11 7718次阅读
    算力网络的“神经突触”:AI互联技术如何重构分布式<b class='flag-5'>训练</b>范式

    海思SD3403边缘计算AI数据训练概述

    模型,将模型转化为嵌入式AI模型模型升级AI摄像机,进行AI识别应用。 AI训练模型是不断迭
    发表于 04-28 11:11