0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

Pytroch中支持的两种迁移学习方式

OpenCV学堂 来源:OpenCV学堂 作者:OpenCV学堂 2022-10-09 15:16 次阅读

torchvision分类介绍

Torchvision高版本支持各种SOTA的图像分类模型,同时还支持不同数据集分类模型的预训练模型的切换。使用起来十分方便快捷,Pytroch中支持两种迁移学习方式,分别是:

- Finetune模式
基于预训练模型,全链路调优参数
- 冻结特征层模式
这种方式只修改输出层的参数,CNN部分的参数冻结
上述两种迁移方式,分别适合大量数据跟少量数据,前一种方式计算跟训练时间会比第二种方式要长点,但是针对大量自定义分类数据效果会比较好。

自定义分类模型修改与训练

加载模型之后,feature_extracting 为true表示冻结模式,否则为finetune模式,相关的代码如下:

def set_parameter_requires_grad(model, feature_extracting):
     if feature_extracting:
         for param in model.parameters():
             param.requires_grad = False
以resnet18为例,修改之后的自定义训练代码如下:
model_ft=models.resnet18(pretrained=True)
num_ftrs=model_ft.fc.in_features
#Herethesizeofeachoutputsampleissetto5.
#Alternatively,itcanbegeneralizedtonn.Linear(num_ftrs,len(class_names)).
model_ft.fc=nn.Linear(num_ftrs,5)

model_ft=model_ft.to(device)

criterion=nn.CrossEntropyLoss()

#Observethatallparametersarebeingoptimized
optimizer_ft=optim.SGD(model_ft.parameters(),lr=0.001,momentum=0.9)

#DecayLRbyafactorof0.1every7epochs
exp_lr_scheduler=lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1)

model_ft=train_model(model_ft,criterion,optimizer_ft,exp_lr_scheduler,
num_epochs=25)

数据集是flowers-dataset,有五个分类分别是:

daisy
dandelion
roses
sunflowers
tulips

全链路调优,迁移学习训练CNN部分的权重参数

Epoch0/24
----------
trainLoss:1.3993Acc:0.5597
validLoss:1.8571Acc:0.7073
Epoch1/24
----------
trainLoss:1.0903Acc:0.6580
validLoss:0.6150Acc:0.7805
Epoch2/24
----------
trainLoss:0.9095Acc:0.6991
validLoss:0.4386Acc:0.8049
Epoch3/24
----------
trainLoss:0.7628Acc:0.7349
validLoss:0.9111Acc:0.7317
Epoch4/24
----------
trainLoss:0.7107Acc:0.7669
validLoss:0.4854Acc:0.8049
Epoch5/24
----------
trainLoss:0.6231Acc:0.7793
validLoss:0.6822Acc:0.8049
Epoch6/24
----------
trainLoss:0.5768Acc:0.8033
validLoss:0.2748Acc:0.8780
Epoch7/24
----------
trainLoss:0.5448Acc:0.8110
validLoss:0.4440Acc:0.7561
Epoch8/24
----------
trainLoss:0.5037Acc:0.8170
validLoss:0.2900Acc:0.9268
Epoch9/24
----------
trainLoss:0.4836Acc:0.8360
validLoss:0.7108Acc:0.7805
Epoch10/24
----------
trainLoss:0.4663Acc:0.8369
validLoss:0.5868Acc:0.8049
Epoch11/24
----------
trainLoss:0.4276Acc:0.8504
validLoss:0.6998Acc:0.8293
Epoch12/24
----------
trainLoss:0.4299Acc:0.8529
validLoss:0.6449Acc:0.8049
Epoch13/24
----------
trainLoss:0.4256Acc:0.8567
validLoss:0.7897Acc:0.7805
Epoch14/24
----------
trainLoss:0.4062Acc:0.8559
validLoss:0.5855Acc:0.8293
Epoch15/24
----------
trainLoss:0.4030Acc:0.8545
validLoss:0.7336Acc:0.7805
Epoch16/24
----------
trainLoss:0.3786Acc:0.8730
validLoss:1.0429Acc:0.7561
Epoch17/24
----------
trainLoss:0.3699Acc:0.8763
validLoss:0.4549Acc:0.8293
Epoch18/24
----------
trainLoss:0.3394Acc:0.8788
validLoss:0.2828Acc:0.9024
Epoch19/24
----------
trainLoss:0.3300Acc:0.8834
validLoss:0.6766Acc:0.8537
Epoch20/24
----------
trainLoss:0.3136Acc:0.8906
validLoss:0.5893Acc:0.8537
Epoch21/24
----------
trainLoss:0.3110Acc:0.8901
validLoss:0.4909Acc:0.8537
Epoch22/24
----------
trainLoss:0.3141Acc:0.8931
validLoss:0.3930Acc:0.9024
Epoch23/24
----------
trainLoss:0.3106Acc:0.8887
validLoss:0.3079Acc:0.9024
Epoch24/24
----------
trainLoss:0.3143Acc:0.8923
validLoss:0.5122Acc:0.8049
Trainingcompletein25m34s
BestvalAcc:0.926829

冻结CNN部分,只训练全连接分类权重

Paramstolearn:
fc.weight
fc.bias
Epoch0/24
----------
trainLoss:1.0217Acc:0.6465
validLoss:1.5317Acc:0.8049
Epoch1/24
----------
trainLoss:0.9569Acc:0.6947
validLoss:1.2450Acc:0.6829
Epoch2/24
----------
trainLoss:1.0280Acc:0.6999
validLoss:1.5677Acc:0.7805
Epoch3/24
----------
trainLoss:0.8344Acc:0.7426
validLoss:1.1053Acc:0.7317
Epoch4/24
----------
trainLoss:0.9110Acc:0.7250
validLoss:1.1148Acc:0.7561
Epoch5/24
----------
trainLoss:0.9049Acc:0.7346
validLoss:1.1541Acc:0.6341
Epoch6/24
----------
trainLoss:0.8538Acc:0.7465
validLoss:1.4098Acc:0.8293
Epoch7/24
----------
trainLoss:0.9041Acc:0.7349
validLoss:0.9604Acc:0.7561
Epoch8/24
----------
trainLoss:0.8885Acc:0.7468
validLoss:1.2603Acc:0.7561
Epoch9/24
----------
trainLoss:0.9257Acc:0.7333
validLoss:1.0751Acc:0.7561
Epoch10/24
----------
trainLoss:0.8637Acc:0.7492
validLoss:0.9748Acc:0.7317
Epoch11/24
----------
trainLoss:0.8686Acc:0.7517
validLoss:1.0194Acc:0.8049
Epoch12/24
----------
trainLoss:0.8492Acc:0.7572
validLoss:1.0378Acc:0.7317
Epoch13/24
----------
trainLoss:0.8773Acc:0.7432
validLoss:0.7224Acc:0.8049
Epoch14/24
----------
trainLoss:0.8919Acc:0.7473
validLoss:1.3564Acc:0.7805
Epoch15/24
----------
trainLoss:0.8634Acc:0.7490
validLoss:0.7822Acc:0.7805
Epoch16/24
----------
trainLoss:0.8069Acc:0.7644
validLoss:1.4132Acc:0.7561
Epoch17/24
----------
trainLoss:0.8589Acc:0.7492
validLoss:0.9812Acc:0.8049
Epoch18/24
----------
trainLoss:0.7677Acc:0.7688
validLoss:0.7176Acc:0.8293
Epoch19/24
----------
trainLoss:0.8044Acc:0.7514
validLoss:1.4486Acc:0.7561
Epoch20/24
----------
trainLoss:0.7916Acc:0.7564
validLoss:1.0575Acc:0.8049
Epoch21/24
----------
trainLoss:0.7922Acc:0.7647
validLoss:1.0406Acc:0.7805
Epoch22/24
----------
trainLoss:0.8187Acc:0.7647
validLoss:1.0965Acc:0.7561
Epoch23/24
----------
trainLoss:0.8443Acc:0.7503
validLoss:1.6163Acc:0.7317
Epoch24/24
----------
trainLoss:0.8165Acc:0.7583
validLoss:1.1680Acc:0.7561
Trainingcompletein20m7s
BestvalAcc:0.829268

测试结果:

零代码训练演示

我已经完成torchvision中分类模型自定义数据集迁移学习的代码封装与开发,支持基于收集到的数据集,零代码训练,生成模型。图示如下:

96f99d5a-47a0-11ed-a3b6-dac502259ad0.png

审核编辑:彭静
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据
    +关注

    关注

    8

    文章

    6515

    浏览量

    87621
  • 模型
    +关注

    关注

    1

    文章

    2709

    浏览量

    47716
  • 迁移学习
    +关注

    关注

    0

    文章

    72

    浏览量

    5503

原文标题:tochvision轻松支持十种图像分类模型迁移学习

文章出处:【微信号:CVSCHOOL,微信公众号:OpenCV学堂】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    两种采样方式

    两种采样方式.....................................
    发表于 08-08 15:04

    请问小车转向两种方式有什么优缺点?

    我知道的小车转向常见的有两种方式,一是通过舵机控制转向,另一是通过控制个轮子的转速,通过转速差实现转向,这
    发表于 05-21 02:37

    SQL语句的两种嵌套方式

    一般情况下,SQL语句是嵌套在宿主语言(如C语言)中的。有两种嵌套方式:1.调用层接口(CLI):提供一些库,库中的函数和方法实现SQL的调用2.直接嵌套SQL:在代码中嵌套SQL语句,提交给预处理器,将SQL语句转换成对宿主语言有意义的内容,如调用库中的函数和方法代替S
    发表于 05-23 08:51

    linux配置mysql的两种方式

    两种方式:a、$ find / -name mysql–print 查看是否有mysql文件夹b、$ netstat -a –n 查看是否打开3306端口
    发表于 07-26 07:46

    Linux实现输入参数求和的两种方式

    Linux实现输入参数求和(两种方式
    发表于 03-26 11:44

    STemWin两种方式移植完整资料整理

    STemWin两种方式移植完整资料整理,初用者可以一看。
    发表于 08-04 16:04

    编译环境的两种搭建方式

    编译环境的两种搭建方式putty工具的使用winscp工具的使用
    发表于 12-22 08:00

    求大神分享基于FPGA的DDFS与DDWS的两种实现方式

    DDS的基本原理是什么,有什么性能指标?基于FPGA的DDFS与DDWS两种实现方式
    发表于 04-30 06:13

    处理器与外部通信的两种方式

    处理器与外部通信的两种方式并行通信数据各个位同时传输,速度快,占用引脚资源多串行通信数据按位顺序传输,占用引脚资源少,速度相对比较慢1.按照数据传送方向可以分为:单工:数据传输只支持在一个方向
    发表于 08-18 08:06

    串口通信的两种方式

    串口通信由两种方式,第一就是用微机原理课上学过的TX和RX个接口进行通信。不过根据去年的经验,这个板子直接用TX和RX个接口进行通信容
    发表于 08-24 06:59

    SQL语言的两种使用方式

    SQL语言的两种使用方式在终端交互方式下使用,称为交互式SQL嵌入在高级语言的程序中使用,称为嵌入式SQL―高级语言如C、Java等,称为宿主语言嵌入式SQL的实现方式源程序(用主语言
    发表于 12-20 06:51

    vnc和xrdp两种远程连接的方式

    [zju嵌入式]树莓派之远程桌面 之前篇介绍了通过串口和ssh登陆到树莓派的方法,这两种方式的有点在于连接方面,响应速度快,但是也有不够直观的缺点,没办法看到图形界面.在这篇博文中,笔者将介绍vnc和xrdp
    发表于 12-24 07:54

    迁移学习

    经典机器学习算法介绍章节目标:机器学习是人工智能的重要技术之一,详细了解机器学习的原理、机制和方法,为学习深度学习
    发表于 04-21 15:15

    分享一智能网卡对热迁移支持的新思路

    的数据面必须实现virtio(一虚拟设备)的数据面。设备的控制面可以按照厂商自定义的格式实现,vDPA会协助完成virtio控制面命令到厂商控制面命令的转换。目前实现vDPA的框架有两种方式,一
    发表于 07-05 14:46

    电流检测两种方式

    电流检测两种方式高端检测既然会使得放大器承受较高的共模电压,那“都采取高侧检测”这句话岂不是自相矛盾怎么理解负载脚底不稳?低端检测会影响到GND电平的稳定性吗?帮忙讲下这四个电路、高端检测2,低端检测2具体的工作原理,或输出公式的推导过程
    发表于 03-06 18:43