torchvision分类介绍
Torchvision高版本支持各种SOTA的图像分类模型,同时还支持不同数据集分类模型的预训练模型的切换。使用起来十分方便快捷,Pytroch中支持两种迁移学习方式,分别是:
- Finetune模式 基于预训练模型,全链路调优参数 - 冻结特征层模式 这种方式只修改输出层的参数,CNN部分的参数冻结上述两种迁移方式,分别适合大量数据跟少量数据,前一种方式计算跟训练时间会比第二种方式要长点,但是针对大量自定义分类数据效果会比较好。
自定义分类模型修改与训练
加载模型之后,feature_extracting 为true表示冻结模式,否则为finetune模式,相关的代码如下:
def set_parameter_requires_grad(model, feature_extracting): if feature_extracting: for param in model.parameters(): param.requires_grad = False以resnet18为例,修改之后的自定义训练代码如下:
model_ft=models.resnet18(pretrained=True) num_ftrs=model_ft.fc.in_features #Herethesizeofeachoutputsampleissetto5. #Alternatively,itcanbegeneralizedtonn.Linear(num_ftrs,len(class_names)). model_ft.fc=nn.Linear(num_ftrs,5) model_ft=model_ft.to(device) criterion=nn.CrossEntropyLoss() #Observethatallparametersarebeingoptimized optimizer_ft=optim.SGD(model_ft.parameters(),lr=0.001,momentum=0.9) #DecayLRbyafactorof0.1every7epochs exp_lr_scheduler=lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) model_ft=train_model(model_ft,criterion,optimizer_ft,exp_lr_scheduler, num_epochs=25)
数据集是flowers-dataset,有五个分类分别是:
daisy dandelion roses sunflowers tulips
全链路调优,迁移学习训练CNN部分的权重参数
Epoch0/24 ---------- trainLoss:1.3993Acc:0.5597 validLoss:1.8571Acc:0.7073 Epoch1/24 ---------- trainLoss:1.0903Acc:0.6580 validLoss:0.6150Acc:0.7805 Epoch2/24 ---------- trainLoss:0.9095Acc:0.6991 validLoss:0.4386Acc:0.8049 Epoch3/24 ---------- trainLoss:0.7628Acc:0.7349 validLoss:0.9111Acc:0.7317 Epoch4/24 ---------- trainLoss:0.7107Acc:0.7669 validLoss:0.4854Acc:0.8049 Epoch5/24 ---------- trainLoss:0.6231Acc:0.7793 validLoss:0.6822Acc:0.8049 Epoch6/24 ---------- trainLoss:0.5768Acc:0.8033 validLoss:0.2748Acc:0.8780 Epoch7/24 ---------- trainLoss:0.5448Acc:0.8110 validLoss:0.4440Acc:0.7561 Epoch8/24 ---------- trainLoss:0.5037Acc:0.8170 validLoss:0.2900Acc:0.9268 Epoch9/24 ---------- trainLoss:0.4836Acc:0.8360 validLoss:0.7108Acc:0.7805 Epoch10/24 ---------- trainLoss:0.4663Acc:0.8369 validLoss:0.5868Acc:0.8049 Epoch11/24 ---------- trainLoss:0.4276Acc:0.8504 validLoss:0.6998Acc:0.8293 Epoch12/24 ---------- trainLoss:0.4299Acc:0.8529 validLoss:0.6449Acc:0.8049 Epoch13/24 ---------- trainLoss:0.4256Acc:0.8567 validLoss:0.7897Acc:0.7805 Epoch14/24 ---------- trainLoss:0.4062Acc:0.8559 validLoss:0.5855Acc:0.8293 Epoch15/24 ---------- trainLoss:0.4030Acc:0.8545 validLoss:0.7336Acc:0.7805 Epoch16/24 ---------- trainLoss:0.3786Acc:0.8730 validLoss:1.0429Acc:0.7561 Epoch17/24 ---------- trainLoss:0.3699Acc:0.8763 validLoss:0.4549Acc:0.8293 Epoch18/24 ---------- trainLoss:0.3394Acc:0.8788 validLoss:0.2828Acc:0.9024 Epoch19/24 ---------- trainLoss:0.3300Acc:0.8834 validLoss:0.6766Acc:0.8537 Epoch20/24 ---------- trainLoss:0.3136Acc:0.8906 validLoss:0.5893Acc:0.8537 Epoch21/24 ---------- trainLoss:0.3110Acc:0.8901 validLoss:0.4909Acc:0.8537 Epoch22/24 ---------- trainLoss:0.3141Acc:0.8931 validLoss:0.3930Acc:0.9024 Epoch23/24 ---------- trainLoss:0.3106Acc:0.8887 validLoss:0.3079Acc:0.9024 Epoch24/24 ---------- trainLoss:0.3143Acc:0.8923 validLoss:0.5122Acc:0.8049 Trainingcompletein25m34s BestvalAcc:0.926829
冻结CNN部分,只训练全连接分类权重
Paramstolearn: fc.weight fc.bias Epoch0/24 ---------- trainLoss:1.0217Acc:0.6465 validLoss:1.5317Acc:0.8049 Epoch1/24 ---------- trainLoss:0.9569Acc:0.6947 validLoss:1.2450Acc:0.6829 Epoch2/24 ---------- trainLoss:1.0280Acc:0.6999 validLoss:1.5677Acc:0.7805 Epoch3/24 ---------- trainLoss:0.8344Acc:0.7426 validLoss:1.1053Acc:0.7317 Epoch4/24 ---------- trainLoss:0.9110Acc:0.7250 validLoss:1.1148Acc:0.7561 Epoch5/24 ---------- trainLoss:0.9049Acc:0.7346 validLoss:1.1541Acc:0.6341 Epoch6/24 ---------- trainLoss:0.8538Acc:0.7465 validLoss:1.4098Acc:0.8293 Epoch7/24 ---------- trainLoss:0.9041Acc:0.7349 validLoss:0.9604Acc:0.7561 Epoch8/24 ---------- trainLoss:0.8885Acc:0.7468 validLoss:1.2603Acc:0.7561 Epoch9/24 ---------- trainLoss:0.9257Acc:0.7333 validLoss:1.0751Acc:0.7561 Epoch10/24 ---------- trainLoss:0.8637Acc:0.7492 validLoss:0.9748Acc:0.7317 Epoch11/24 ---------- trainLoss:0.8686Acc:0.7517 validLoss:1.0194Acc:0.8049 Epoch12/24 ---------- trainLoss:0.8492Acc:0.7572 validLoss:1.0378Acc:0.7317 Epoch13/24 ---------- trainLoss:0.8773Acc:0.7432 validLoss:0.7224Acc:0.8049 Epoch14/24 ---------- trainLoss:0.8919Acc:0.7473 validLoss:1.3564Acc:0.7805 Epoch15/24 ---------- trainLoss:0.8634Acc:0.7490 validLoss:0.7822Acc:0.7805 Epoch16/24 ---------- trainLoss:0.8069Acc:0.7644 validLoss:1.4132Acc:0.7561 Epoch17/24 ---------- trainLoss:0.8589Acc:0.7492 validLoss:0.9812Acc:0.8049 Epoch18/24 ---------- trainLoss:0.7677Acc:0.7688 validLoss:0.7176Acc:0.8293 Epoch19/24 ---------- trainLoss:0.8044Acc:0.7514 validLoss:1.4486Acc:0.7561 Epoch20/24 ---------- trainLoss:0.7916Acc:0.7564 validLoss:1.0575Acc:0.8049 Epoch21/24 ---------- trainLoss:0.7922Acc:0.7647 validLoss:1.0406Acc:0.7805 Epoch22/24 ---------- trainLoss:0.8187Acc:0.7647 validLoss:1.0965Acc:0.7561 Epoch23/24 ---------- trainLoss:0.8443Acc:0.7503 validLoss:1.6163Acc:0.7317 Epoch24/24 ---------- trainLoss:0.8165Acc:0.7583 validLoss:1.1680Acc:0.7561 Trainingcompletein20m7s BestvalAcc:0.829268
测试结果:
零代码训练演示
我已经完成torchvision中分类模型自定义数据集迁移学习的代码封装与开发,支持基于收集到的数据集,零代码训练,生成模型。图示如下:
审核编辑:彭静
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。
举报投诉
-
数据
+关注
关注
8文章
6515浏览量
87621 -
模型
+关注
关注
1文章
2709浏览量
47716 -
迁移学习
+关注
关注
0文章
72浏览量
5503
原文标题:tochvision轻松支持十种图像分类模型迁移学习
文章出处:【微信号:CVSCHOOL,微信公众号:OpenCV学堂】欢迎添加关注!文章转载请注明出处。
发布评论请先 登录
相关推荐
SQL语句的两种嵌套方式
一般情况下,SQL语句是嵌套在宿主语言(如C语言)中的。有两种嵌套方式:1.调用层接口(CLI):提供一些库,库中的函数和方法实现SQL的调用2.直接嵌套SQL:在代码中嵌套SQL语句,提交给预处理器,将SQL语句转换成对宿主语言有意义的内容,如调用库中的函数和方法代替S
发表于 05-23 08:51
linux配置mysql的两种方式
两种方式:a、$ find / -name mysql–print 查看是否有mysql文件夹b、$ netstat -a –n 查看是否打开3306端口
发表于 07-26 07:46
处理器与外部通信的两种方式
处理器与外部通信的两种方式并行通信数据各个位同时传输,速度快,占用引脚资源多串行通信数据按位顺序传输,占用引脚资源少,速度相对比较慢1.按照数据传送方向可以分为:单工:数据传输只支持在一个方向
发表于 08-18 08:06
SQL语言的两种使用方式
SQL语言的两种使用方式在终端交互方式下使用,称为交互式SQL嵌入在高级语言的程序中使用,称为嵌入式SQL―高级语言如C、Java等,称为宿主语言嵌入式SQL的实现方式源程序(用主语言
发表于 12-20 06:51
vnc和xrdp两种远程连接的方式
[zju嵌入式]树莓派之远程桌面 之前两篇介绍了通过串口和ssh登陆到树莓派的方法,这两种方式的有点在于连接方面,响应速度快,但是也有不够直观的缺点,没办法看到图形界面.在这篇博文中,笔者将介绍vnc和xrdp
发表于 12-24 07:54
分享一种智能网卡对热迁移支持的新思路
的数据面必须实现virtio(一种虚拟设备)的数据面。设备的控制面可以按照厂商自定义的格式实现,vDPA会协助完成virtio控制面命令到厂商控制面命令的转换。目前实现vDPA的框架有两种方式,一
发表于 07-05 14:46
电流检测两种方式
电流检测两种方式高端检测既然会使得放大器承受较高的共模电压,那“都采取高侧检测”这句话岂不是自相矛盾怎么理解负载脚底不稳?低端检测会影响到GND电平的稳定性吗?帮忙讲下这四个电路、高端检测2,低端检测2具体的工作原理,或输出公式的推导过程
发表于 03-06 18:43
评论