0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

那些年在pytorch上踩过的坑

jf_78858299 来源:天宏NLP 作者:tianhongzxy 2023-02-22 14:18 次阅读

今天又发现了一个pytorch的小坑,给大家分享一下。手上两份同一模型的代码,一份用tensorflow写的,另一份是我拿pytorch写的,模型架构一模一样,预处理数据的逻辑也一模一样,测试发现模型推理的速度也差不多。一份预处理代码是为pytorch模型写的,用到的库是torch,另一份是为tensorflow写的,用到的是numpy。在训练时,每个epoch耗时居然差距非常大,pytorch的代码在140w条数据上训练每轮耗时约45min,而tensorflow版的代码耗时仅约12min。

我把代码看了又看,百思不得其解,预处理的代码比较复杂,都包含两个for循环,pytorch版代码我把更多的预处理步骤放到了Dataset里,这样训练时加载每个batch后,再要处理的步骤就更少了,速度也应该更快,而tensorflow版代码的for循环里预处理的步骤明明更多,怎么会速度比我的代码还快呢?然而,经过我的测试发现,从加载每个batch的数据进来开始,经过预处理,直到输入到模型做计算前,两者的耗时差了约7~8倍。最后发现问题出在对pytorch的tensor进行了频繁的索引操作。

下面做个实验给大家直观体验一下,对tensor做索引和对array做索引的速度差距有多大,tensorarray都是大小(1000x1000)的二维数组。

Pytorch(version==1.4.1)索引1000000次耗时:3.51秒

图片

Numpy索引1000000次耗时:0.43秒

图片

我还特意对比了一下对TensorFlow的tensor做索引的耗时

TensorFlow(version==2.1.0)索引1000000次耗时:118.89秒

图片

由此可见tensorarray的索引速度至少差距在10倍,不过这也在情理之中,毕竟tensor要比array“重”得多。因此在使用pytorch和tensorflow时,频繁需要索引的操作一定要先把tensor转换为numpy.array来做!

除此之外,与其对二维数组进行索引,不如将其展平为一维数组,算上展平的时间,速度还会有不少提升。

Pytorch从3.51秒降到了1.94秒

图片

Numpy从0.43秒降到了0.29秒

图片

如果在训练和数据预处理过程中发现自己的代码跑起来速度非常慢,记得看一看有没有对tensor做太多次索引,如果有的话,要把它转为numpy.array,还有,尽量把二维、三维的索引变成一维的索引,这些都能加快你训练模型的速度。

PS:最后我的代码终于训练一轮也只需要不到12min了,后来又找了点加速的办法,把训练一轮的时间控制到了9min以内,这些就放在以后再写吧~

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 代码
    +关注

    关注

    30

    文章

    4552

    浏览量

    66642
  • tensorflow
    +关注

    关注

    13

    文章

    313

    浏览量

    60241
  • pytorch
    +关注

    关注

    2

    文章

    759

    浏览量

    12822
收藏 人收藏

    评论

    相关推荐

    使用STM32采集电池电压那些

    本文来解析一个盆友在使用STM32采集电池电压。以STM32F4 的ADC属于逐次逼近SAR 型ADC为例进行分析,参考STM32F405xxDatasheet,对于如何编写ADC程序就不做描述了。
    发表于 03-01 07:39

    开发STM32 USB HID

    记录一下 开发STM32 USB HID一、前言二、代码配置一、前言MCU: STM32F103C8T6CubeMX: STM32CubeMX 5.3.0二、代码配置引脚配置时钟树配置我
    发表于 08-24 07:15

    使用树莓派搭建stm32开发环境以及碰到的问题

    使用树莓派搭建stm32开发环境了很多,下面主要是记录一下,以及碰到的问题。##开发方式的选择1.使用Eclipse+GDB+O
    发表于 08-24 07:47

    有没有关于STM32入门经验分享

    有没有关于STM32入门经验分享
    发表于 10-13 06:52

    NodeMCU开发板经历分享

    写在前面今天入手了一个NodeMCU的板子,准备学习一下物联网相关的知识。不过由于博主学艺不精,在第一步烧写固件了,所以就想着把自己的
    发表于 11-01 07:55

    Linux学习过程与如何解决

    Linux记录记录Linux学习过程与如何解决
    发表于 11-04 08:44

    移植debian系统

    基本的linux系统,板子的交叉编译器是arm-linux-gnueabihf-gcc,这给我带来了不少的麻烦,以至于想重新移植一下debian系统。ok,转入正题,说说这两天我吧。首先...
    发表于 12-14 08:42

    STM32编程常有哪些?

    STM32编程常有哪些?
    发表于 12-17 06:15

    使用MDK5时出现的一些error分享

    使用MDK5时出现的一些error分享
    发表于 12-17 07:49

    记录写SAM4S的bootloader所

    记录写SAM4S的bootloader所
    发表于 01-24 07:16

    Arduino-IDE配置ESP32开发环境的正确方式

    Arduino-IDE配置ESP32-CAM开发环境那些Arduino-IDE配置ESP32开发环境
    发表于 01-25 07:40

    关于RK1808板子调试过程记录

    关于RK1808板子调试过程记录
    发表于 02-16 06:38

    STM32G070CB cubemx串口调试哪些

    使用G070CB时写的中断程序是怎样的?STM32G070CB cubemx串口调试哪些呢?
    发表于 02-18 06:08

    【国民技术N32项目移植】汇总一下我那些

    【国民技术N32项目移植】汇总一下我那些国民技术与电子发烧友联合举办的N32 MCU移植挑战赛,从10月份开始报名,到现在已经持续好几个月了,现在马上就接近最后交作品的日期了,
    发表于 02-28 16:42

    那些年在pytorch上过的当

    最近在修改上一个同事加载和预处理数据的代码,原版的代码使用tf1.4.1写的,数据加载也是完全就是for循环读取+预处理,每读入并预处理好一个batch就返回丢给模型训练,如此往复,我觉得速度实在太慢了,而且我新写的代码都是基于pytorch,虽然预处理的过程很复杂,我还是下决心自己改写。
    的头像 发表于 02-22 14:19 326次阅读
    <b class='flag-5'>那些</b><b class='flag-5'>年在</b><b class='flag-5'>pytorch</b>上过的当