0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

掌握PyTorch图片分类的简明教程

WpOh_rgznai100 来源:lq 2019-07-18 15:24 次阅读

1.引文

深度学习的比赛中,图片分类是很常见的比赛,同时也是很难取得特别高名次的比赛,因为图片分类已经被大家研究的很透彻,一些开源的网络很容易取得高分。如果大家还掌握不了使用开源的网络进行训练,再慢慢去模型调优,很难取得较好的成绩。

我们在[PyTorch小试牛刀]实战六·准备自己的数据集用于训练讲解了如何制作自己的数据集用于训练,这个教程在此基础上,进行训练与应用。

(实战六链接:

https://blog.csdn.net/xiaosongshine/article/details/85225873)

2.数据介绍

数据下载地址:

https://download.csdn.net/download/xiaosongshine/11128410

这次的实战使用的数据是交通标志数据集,共有62类交通标志。其中训练集数据有4572张照片(每个类别大概七十个),测试数据集有2520张照片(每个类别大概40个)。数据包含两个子目录分别train与test:

为什么还需要测试数据集呢?这个测试数据集不会拿来训练,是用来进行模型的评估与调优。

train与test每个文件夹里又有62个子文件夹,每个类别在同一个文件夹内:

我从中打开一个文件间,把里面图片展示出来:

其中每张照片都类似下面的例子,100*100*3的大小。100是照片的照片的长和宽,3是什么呢?这其实是照片的色彩通道数目,RGB。彩色照片存储在计算机里就是以三维数组的形式。我们送入网络的也是这些数组。

3.网络构建

1.导入Python包,定义一些参数

1importtorchast 2importtorchvisionastv 3importos 4importtime 5importnumpyasnp 6fromtqdmimporttqdm 7 8 9classDefaultConfigs(object):1011data_dir="./traffic-sign/"12data_list=["train","test"]1314lr=0.00115epochs=1016num_classes=6217image_size=22418batch_size=4019channels=320gpu="0"21train_len=457222test_len=252023use_gpu=t.cuda.is_available()2425config=DefaultConfigs()

2.数据准备,采用PyTorch提供的读取方式

注意一点Train数据需要进行随机裁剪,Test数据不要进行裁剪了

1normalize=tv.transforms.Normalize(mean=[0.485,0.456,0.406], 2std=[0.229,0.224,0.225] 3) 4 5transform={ 6config.data_list[0]:tv.transforms.Compose( 7[tv.transforms.Resize([224,224]),tv.transforms.CenterCrop([224,224]), 8tv.transforms.ToTensor(),normalize]#tv.transforms.Resize用于重设图片大小 9),10config.data_list[1]:tv.transforms.Compose(11[tv.transforms.Resize([224,224]),tv.transforms.ToTensor(),normalize]12)13}1415datasets={16x:tv.datasets.ImageFolder(root=os.path.join(config.data_dir,x),transform=transform[x])17forxinconfig.data_list18}1920dataloader={21x:t.utils.data.DataLoader(dataset=datasets[x],22batch_size=config.batch_size,23shuffle=True24)25forxinconfig.data_list26}

3.构建网络模型(使用resnet18进行迁移学习,训练参数为最后一个全连接层 t.nn.Linear(512,num_classes))

1defget_model(num_classes): 2 3model=tv.models.resnet18(pretrained=True) 4forparmainmodel.parameters(): 5parma.requires_grad=False 6model.fc=t.nn.Sequential( 7t.nn.Dropout(p=0.3), 8t.nn.Linear(512,num_classes) 9)10return(model)

如果电脑硬件支持,可以把下述代码屏蔽,则训练整个网络,最终准确率会上升,训练数据会变慢。

1forparmainmodel.parameters():2parma.requires_grad=False

模型输出

1ResNet( 2(conv1):Conv2d(3,64,kernel_size=(7,7),stride=(2,2),padding=(3,3),bias=False) 3(bn1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True) 4(relu):ReLU(inplace) 5(maxpool):MaxPool2d(kernel_size=3,stride=2,padding=1,dilation=1,ceil_mode=False) 6(layer1):Sequential( 7(0):BasicBlock( 8(conv1):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False) 9(bn1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)10(relu):ReLU(inplace)11(conv2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)12(bn2):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)13)14(1):BasicBlock(15(conv1):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)16(bn1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)17(relu):ReLU(inplace)18(conv2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)19(bn2):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)20)21)22(layer2):Sequential(23(0):BasicBlock(24(conv1):Conv2d(64,128,kernel_size=(3,3),stride=(2,2),padding=(1,1),bias=False)25(bn1):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)26(relu):ReLU(inplace)27(conv2):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)28(bn2):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)29(downsample):Sequential(30(0):Conv2d(64,128,kernel_size=(1,1),stride=(2,2),bias=False)31(1):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)32)33)34(1):BasicBlock(35(conv1):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)36(bn1):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)37(relu):ReLU(inplace)38(conv2):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)39(bn2):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)40)41)42(layer3):Sequential(43(0):BasicBlock(44(conv1):Conv2d(128,256,kernel_size=(3,3),stride=(2,2),padding=(1,1),bias=False)45(bn1):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)46(relu):ReLU(inplace)47(conv2):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)48(bn2):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)49(downsample):Sequential(50(0):Conv2d(128,256,kernel_size=(1,1),stride=(2,2),bias=False)51(1):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)52)53)54(1):BasicBlock(55(conv1):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)56(bn1):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)57(relu):ReLU(inplace)58(conv2):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)59(bn2):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)60)61)62(layer4):Sequential(63(0):BasicBlock(64(conv1):Conv2d(256,512,kernel_size=(3,3),stride=(2,2),padding=(1,1),bias=False)65(bn1):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)66(relu):ReLU(inplace)67(conv2):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)68(bn2):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)69(downsample):Sequential(70(0):Conv2d(256,512,kernel_size=(1,1),stride=(2,2),bias=False)71(1):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)72)73)74(1):BasicBlock(75(conv1):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)76(bn1):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)77(relu):ReLU(inplace)78(conv2):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)79(bn2):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)80)81)82(avgpool):AvgPool2d(kernel_size=7,stride=1,padding=0)83(fc):Sequential(84(0):Dropout(p=0.3)85(1):Linear(in_features=512,out_features=62,bias=True)86)87)

4.训练模型(支持自动GPU加速)

1deftrain(epochs): 2 3model=get_model(config.num_classes) 4print(model) 5loss_f=t.nn.CrossEntropyLoss() 6if(config.use_gpu): 7model=model.cuda() 8loss_f=loss_f.cuda() 910opt=t.optim.Adam(model.fc.parameters(),lr=config.lr)11time_start=time.time()1213forepochinrange(epochs):14train_loss=[]15train_acc=[]16test_loss=[]17test_acc=[]18model.train(True)19print("Epoch{}/{}".format(epoch+1,epochs))20forbatch,datasintqdm(enumerate(iter(dataloader["train"]))):21x,y=datas22if(config.use_gpu):23x,y=x.cuda(),y.cuda()24y_=model(x)25#print(x.shape,y.shape,y_.shape)26_,pre_y_=t.max(y_,1)27pre_y=y28#print(y_.shape)29loss=loss_f(y_,pre_y)30#print(y_.shape)31acc=t.sum(pre_y_==pre_y)3233loss.backward()34opt.step()35opt.zero_grad()36if(config.use_gpu):37loss=loss.cpu()38acc=acc.cpu()39train_loss.append(loss.data)40train_acc.append(acc)41#if((batch+1)%5==0):42time_end=time.time()43print("Batch{},Trainloss:{:.4f},Trainacc:{:.4f},Time:{}"\44.format(batch+1,np.mean(train_loss)/config.batch_size,np.mean(train_acc)/config.batch_size,(time_end-time_start)))45time_start=time.time()4647model.train(False)48forbatch,datasintqdm(enumerate(iter(dataloader["test"]))):49x,y=datas50if(config.use_gpu):51x,y=x.cuda(),y.cuda()52y_=model(x)53#print(x.shape,y.shape,y_.shape)54_,pre_y_=t.max(y_,1)55pre_y=y56#print(y_.shape)57loss=loss_f(y_,pre_y)58acc=t.sum(pre_y_==pre_y)5960if(config.use_gpu):61loss=loss.cpu()62acc=acc.cpu()6364test_loss.append(loss.data)65test_acc.append(acc)66print("Batch{},Testloss:{:.4f},Testacc:{:.4f}".format(batch+1,np.mean(test_loss)/config.batch_size,np.mean(test_acc)/config.batch_size))6768t.save(model,str(epoch+1)+"ttmodel.pkl")69707172if__name__=="__main__":73train(config.epochs)

训练结果如下:

1Epoch1/10 2115it[00:48,2.63it/s] 3Batch115,Trainloss:0.0590,Trainacc:0.4635,Time:48.985504150390625 463it[00:24,2.62it/s] 5Batch63,Testloss:0.0374,Testacc:0.6790,Time:24.648272275924683 6Epoch2/10 7115it[00:45,3.22it/s] 8Batch115,Trainloss:0.0271,Trainacc:0.7576,Time:45.68823838233948 963it[00:23,2.62it/s]10Batch63,Testloss:0.0255,Testacc:0.7524,Time:23.27178287506103511Epoch3/1012115it[00:45,3.19it/s]13Batch115,Trainloss:0.0181,Trainacc:0.8300,Time:45.926485061645511463it[00:23,2.60it/s]15Batch63,Testloss:0.0212,Testacc:0.7861,Time:23.8078927993774416Epoch4/1017115it[00:45,3.28it/s]18Batch115,Trainloss:0.0138,Trainacc:0.8767,Time:45.275250196456911963it[00:23,2.57it/s]20Batch63,Testloss:0.0173,Testacc:0.8385,Time:23.73632144927978521Epoch5/1022115it[00:44,3.22it/s]23Batch115,Trainloss:0.0112,Trainacc:0.8950,Time:44.9836382865905762463it[00:22,2.69it/s]25Batch63,Testloss:0.0156,Testacc:0.8520,Time:22.79007434844970726Epoch6/1027115it[00:44,3.19it/s]28Batch115,Trainloss:0.0095,Trainacc:0.9159,Time:45.104269504547122963it[00:22,2.77it/s]30Batch63,Testloss:0.0158,Testacc:0.8214,Time:22.8041245937347431Epoch7/1032115it[00:45,2.95it/s]33Batch115,Trainloss:0.0081,Trainacc:0.9280,Time:45.304390430450443463it[00:23,2.66it/s]35Batch63,Testloss:0.0139,Testacc:0.8528,Time:23.12237954139709536Epoch8/1037115it[00:44,3.23it/s]38Batch115,Trainloss:0.0073,Trainacc:0.9300,Time:44.3047628402709963963it[00:22,2.74it/s]40Batch63,Testloss:0.0142,Testacc:0.8496,Time:22.80183553695678741Epoch9/1042115it[00:43,3.19it/s]43Batch115,Trainloss:0.0068,Trainacc:0.9361,Time:44.084140300750734463it[00:23,2.44it/s]45Batch63,Testloss:0.0142,Testacc:0.8437,Time:23.60441923141479546Epoch10/1047115it[00:46,3.12it/s]48Batch115,Trainloss:0.0063,Trainacc:0.9337,Time:46.765970468521124963it[00:24,2.65it/s]50Batch63,Testloss:0.0130,Testacc:0.8591,Time:24.64351773262024

训练10个Epoch,测试集准确率可以到达0.86,已经达到不错效果。通过修改参数,增加训练,可以达到更高的准确率。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据集
    +关注

    关注

    4

    文章

    1178

    浏览量

    24351
  • pytorch
    +关注

    关注

    2

    文章

    761

    浏览量

    12835

原文标题:实战:掌握PyTorch图片分类的简明教程 | 附完整代码

文章出处:【微信号:rgznai100,微信公众号:rgznai100】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    Protel DXP简明教

    Protel DXP简明教程电路设计自动化( Electronic Design Automation ) EDA 指的就是将电路设计中各种工作交由计算机来协助完成。如电路图( Schematic
    发表于 12-09 15:13

    DSP中断设置简明教

    DSP中断设置简明教程 from hellodsp
    发表于 10-24 20:51

    proteus简明教程下载

    本帖最后由 tzshlyt 于 2012-11-24 19:10 编辑 proteus简明教程电驴下载
    发表于 11-24 19:06

    protel简明教

    protel简明教
    发表于 01-13 13:47

    vivado简明教

    本帖最后由 burkfly 于 2014-1-14 22:26 编辑 vivado简明教程,初学者有用!
    发表于 01-14 22:22

    PROTEL简明教

    简明教程希望对你有用
    发表于 03-27 11:11

    Vivado 简明教

    Vivado 简明教
    发表于 05-07 11:25

    ADS版图导入、编辑、仿真简明教

    ADS版图导入、编辑、仿真简明教
    发表于 09-12 16:10 0次下载

    ZEMAX光学辅助设计简明教

    ZEMAX光学辅助设计简明教程 ZEMAX光学辅助设计简明教
    发表于 10-30 17:57 0次下载

    电工学简明教程习题+答案

    电工学简明教程习题+答案高清pdf版本电工学简明教程习题+答案高清pdf版本
    发表于 02-25 14:13 16次下载

    Protel99简明教

    本文档详细的介绍了Protel99使用的简明教
    发表于 08-30 17:02 0次下载

    Altium-Designer-10简明教

    Altium-Designer-10简明教
    发表于 12-16 22:13 0次下载

    基于DSP中断设置简明教

    基于DSP中断设置简明教
    发表于 10-23 14:28 5次下载
    基于DSP中断设置<b class='flag-5'>简明教</b>程

    电磁兼容简明教程(4)​共模干扰与差模干扰

    电磁兼容简明教程(4)​共模干扰与差模干扰
    的头像 发表于 12-05 15:04 377次阅读
    电磁兼容<b class='flag-5'>简明教</b>程(4)​共模干扰与差模干扰

    电磁兼容简明教程(1)

    电磁兼容简明教程(1)
    的头像 发表于 12-05 16:23 278次阅读
    电磁兼容<b class='flag-5'>简明教</b>程(1)