0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

pytorch实现断电继续训练时需要注意的要点

新机器视觉 来源:新机器视觉 作者:新机器视觉 2022-08-22 09:50 次阅读

导读

本文整理了pytorch实现断电继续训练时需要注意的要点,附有代码详解。

最近在尝试用CIFAR10训练分类问题的时候,由于数据集体量比较大,训练的过程中时间比较长,有时候想给停下来,但是停下来了之后就得重新训练,之前师兄让我们学习断点继续训练及继续训练的时候注意epoch的改变等,今天上午给大致整理了一下,不全面仅供参考


Epoch:  9 | train loss: 0.3517 | test accuracy: 0.7184 | train time: 14215.1018  sEpoch:  9 | train loss: 0.2471 | test accuracy: 0.7252 | train time: 14309.1216  sEpoch:  9 | train loss: 0.4335 | test accuracy: 0.7201 | train time: 14403.2398  sEpoch:  9 | train loss: 0.2186 | test accuracy: 0.7242 | train time: 14497.1921  sEpoch:  9 | train loss: 0.2127 | test accuracy: 0.7196 | train time: 14591.4974  sEpoch:  9 | train loss: 0.1624 | test accuracy: 0.7142 | train time: 14685.7034  sEpoch:  9 | train loss: 0.1795 | test accuracy: 0.7170 | train time: 14780.2831  s绝望!!!!!训练到了一定次数发现训练次数少了,或者中途断了又得重新开始训练

一、模型的保存与加载

PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)

torch.save主要参数:obj:对象 、f:输出路径

torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu

模型的保存的两种方法:

1、保存整个Module


torch.save(net, path)

2、保存模型参数


state_dict = net.state_dict()torch.save(state_dict , path)

二、模型的训练过程中保存


checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }

网络训练过程中的网络的权重,优化器的权重保存,以及epoch 保存,便于继续训练恢复

在训练过程中,可以根据自己的需要,每多少代,或者多少epoch保存一次网络参数,便于恢复,提高程序的鲁棒性。

checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }    if not os.path.isdir("./models/checkpoint"):        os.mkdir("./models/checkpoint")torch.save(checkpoint,'./models/checkpoint/ckpt_best_%s.pth'%(str(epoch)))
通过上述的过程可以在训练过程自动在指定位置创建文件夹,并保存断点文件

e331992c-20d9-11ed-ba43-dac502259ad0.png

三、模型的断点继续训练


if RESUME:    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch

指出这里的是否继续训练,及训练的checkpoint的文件位置等可以通过argparse从命令行直接读取,也可以通过log文件直接加载,也可以自己在代码中进行修改。关于argparse参照我的这一篇文章:

HUST小菜鸡:argparse 命令行选项、参数和子命令解析器

https://zhuanlan.zhihu.com/p/133285373

四、重点在于epoch的恢复


start_epoch = -1

if RESUME:    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch


for epoch in  range(start_epoch + 1 ,EPOCH):    # print('EPOCH:',epoch)    for step, (b_img,b_label) in enumerate(train_loader):        train_output = model(b_img)        loss = loss_func(train_output,b_label)        # losses.append(loss)        optimizer.zero_grad()        loss.backward()        optimizer.step()

通过定义start_epoch变量来保证继续训练的时候epoch不会变化

e340886a-20d9-11ed-ba43-dac502259ad0.jpg

断点继续训练

一、初始化随机数种子


import torchimport randomimport numpy as np
def set_random_seed(seed = 10,deterministic=False,benchmark=False):    random.seed(seed)    np.random(seed)    torch.manual_seed(seed)    torch.cuda.manual_seed_all(seed)    if deterministic:        torch.backends.cudnn.deterministic = True    if benchmark:        torch.backends.cudnn.benchmark = True

关于torch.backends.cudnn.deterministic和torch.backends.cudnn.benchmark详见

Pytorch学习0.01:cudnn.benchmark= True的设置

https://www.cnblogs.com/captain-dl/p/11938864.html

pytorch---之cudnn.benchmark和cudnn.deterministic_人工智能_zxyhhjs2017的博客

https://blog.csdn.net/zxyhhjs2017/article/details/91348108

e34baefc-20d9-11ed-ba43-dac502259ad0.png

benchmark用在输入尺寸一致,可以加速训练,deterministic用来固定内部随机性

二、多步长SGD继续训练

在简单的任务中,我们使用固定步长(也就是学习率LR)进行训练,但是如果学习率lr设置的过小的话,则会导致很难收敛,如果学习率很大的时候,就会导致在最小值附近,总会错过最小值,loss产生震荡,无法收敛。所以这要求我们要对于不同的训练阶段使用不同的学习率,一方面可以加快训练的过程,另一方面可以加快网络收敛。

采用多步长 torch.optim.lr_scheduler的多种步长设置方式来实现步长的控制,lr_scheduler的各种使用推荐参考如下教程

【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

https://www.cnblogs.com/devilmaycry812839668/p/10630302.html

所以我们在保存网络中的训练的参数的过程中,还需要保存lr_scheduler的state_dict,然后断点继续训练的时候恢复

#这里我设置了不同的epoch对应不同的学习率衰减,在10->20->30,学习率依次衰减为原来的0.1,即一个数量级lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
for epoch in range(start_epoch+1,80):    optimizer.zero_grad()    optimizer.step()    lr_schedule.step()
    if epoch %10 ==0:        print('epoch:',epoch)print('learningrate:',optimizer.state_dict()['param_groups'][0]['lr'])
lr的变化过程如下:

epoch: 10learning rate: 0.1epoch: 20learning rate: 0.010000000000000002epoch: 30learning rate: 0.0010000000000000002epoch: 40learning rate: 0.00010000000000000003epoch: 50learning rate: 1.0000000000000004e-05epoch: 60learning rate: 1.0000000000000004e-06epoch: 70learning rate: 1.0000000000000004e-06

我们在保存的时候,也需要对lr_scheduler的state_dict进行保存,断点继续训练的时候也需要恢复lr_scheduler

#加载恢复if RESUME:    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch    lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler


#保存for epoch in range(start_epoch+1,80):
    optimizer.zero_grad()
    optimizer.step()    lr_schedule.step()

    if epoch %10 ==0:        print('epoch:',epoch)        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])        checkpoint = {            "net": model.state_dict(),            'optimizer': optimizer.state_dict(),            "epoch": epoch,            'lr_schedule': lr_schedule.state_dict()        }        if not os.path.isdir("./model_parameter/test"):            os.mkdir("./model_parameter/test")torch.save(checkpoint,'./model_parameter/test/ckpt_best_%s.pth'%(str(epoch)))

三、保存最好的结果

每一个epoch中的每个step会有不同的结果,可以保存每一代最好的结果,用于后续的训练

第一次实验代码

RESUME = True
EPOCH = 40LR = 0.0005

model = cifar10_cnn.CIFAR10_CNN()
print(model)optimizer = torch.optim.Adam(model.parameters(),lr=LR)loss_func = nn.CrossEntropyLoss()
start_epoch = -1

if RESUME:    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch


for epoch in  range(start_epoch + 1 ,EPOCH):    # print('EPOCH:',epoch)    for step, (b_img,b_label) in enumerate(train_loader):        train_output = model(b_img)        loss = loss_func(train_output,b_label)        # losses.append(loss)        optimizer.zero_grad()        loss.backward()        optimizer.step()
        if step % 100 == 0:            now = time.time()            print('EPOCH:',epoch,'| step :',step,'| loss :',loss.data.numpy(),'| train time: %.4f'%(now-start_time))
    checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }    if not os.path.isdir("./models/checkpoint"):        os.mkdir("./models/checkpoint")torch.save(checkpoint,'./models/checkpoint/ckpt_best_%s.pth'%(str(epoch)))

更新实验代码

optimizer = torch.optim.SGD(model.parameters(),lr=0.1)lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)start_epoch = 9# print(schedule)

if RESUME:    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch    lr_schedule.load_state_dict(checkpoint['lr_schedule'])
for epoch in range(start_epoch+1,80):
    optimizer.zero_grad()
    optimizer.step()    lr_schedule.step()

    if epoch %10 ==0:        print('epoch:',epoch)        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])        checkpoint = {            "net": model.state_dict(),            'optimizer': optimizer.state_dict(),            "epoch": epoch,            'lr_schedule': lr_schedule.state_dict()        }        if not os.path.isdir("./model_parameter/test"):            os.mkdir("./model_parameter/test")torch.save(checkpoint,'./model_parameter/test/ckpt_best_%s.pth'%(str(epoch)))
审核编辑:彭静
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据
    +关注

    关注

    8

    文章

    6504

    浏览量

    87447
  • 硬盘
    +关注

    关注

    3

    文章

    1225

    浏览量

    56194
  • pytorch
    +关注

    关注

    2

    文章

    758

    浏览量

    12794

原文标题:实操教程|PyTorch实现断点继续训练

文章出处:【微信号:vision263com,微信公众号:新机器视觉】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    双绞线布线要注意哪些要点

      双绞线在施工布线的时候也要注意六大要点,下面本文给大家具体分析一下。  一、双绞线必须符合国家双绞线产品标准:  1、所用双绞线要为纯铜芯,线径为0.5毫米。  2、单芯百米电阻为9.38欧姆
    发表于 07-18 11:45

    Pytorch模型训练实用PDF教程【中文】

    及优化器,从而给大家带来清晰的机器学习结构。通过本教程,希望能够给大家带来一个清晰的模型训练结构。当模型训练遇到问题时,需要通过可视化工具对数据、模型、损失等内容进行观察,分析并定位问题出在数据部分
    发表于 12-21 09:18

    电机驱动MCU技术有哪些要点需要注意

    电机驱动MCU技术有哪些要点需要注意
    发表于 04-09 06:19

    嵌入式系统设计时需要注意的技术要点实现细节有哪些?

    为什么需要safe mode?嵌入式系统设计时需要注意的技术要点实现细节有哪些?
    发表于 04-25 08:49

    pytorch模型转换需要注意的事项有哪些?

    什么是JIT(torch.jit)? 答:JIT(Just-In-Time)是一组编译工具,用于弥合PyTorch研究与生产之间的差距。它允许创建可以在不依赖Python解释器的情况下运行的模型
    发表于 09-18 08:05

    选择电磁阀要注意的四大要点

    电磁阀选择要注意四大要点“适用性、可靠性、安全性、经济性”
    的头像 发表于 06-13 17:30 2713次阅读

    使用威格士叶片泵时需要注意什么

    很多用户在运用威格士叶片泵期间都需要注意使用的,如果我们不注意,会给泵带来事故现象,当威格士叶片泵容积泵中的滑片泵应该需要注意哪些事项?威格士叶片泵的管理要点除需防干转和过载、防吸入空
    发表于 09-02 17:20 360次阅读

    使用贴片三极管时需要注意哪些问题

      在使用贴片三极管过程,需要注意的地方有很多,如果使用不当,会降低工作效率以及它的使用寿命?下面由平尚小编分享以下几点需要注意的。
    发表于 02-12 10:53 1280次阅读

    电子琴设计中要注意哪些要点

    引起了很多同学的兴趣,活动正式发布出来,就有不少同学纷纷下单。在这里顺便给同学们梳理一下要做出这个电子琴需要用到哪些书本知识?设计中要注意哪些要点
    的头像 发表于 07-01 16:43 1354次阅读

    地埋灯的安装需要注意哪些事项?

    哪些事项?接下来大成智慧地埋灯厂家的小编就来给大家做详细的介绍,其实地埋灯的安装需要注意以下几点: 1、地埋灯安装前必须切断电源,这是所有电气设备安装和安全运行的第一步。 2、整理所需的配件。地埋灯是埋在地下的特殊景观灯。
    的头像 发表于 04-14 11:32 991次阅读

    PyTorch教程15.10之预训练BERT

    电子发烧友网站提供《PyTorch教程15.10之预训练BERT.pdf》资料免费下载
    发表于 06-05 10:53 0次下载
    <b class='flag-5'>PyTorch</b>教程15.10之预<b class='flag-5'>训练</b>BERT

    无铅锡膏印刷时需要注意哪些技术要点?

    这种比较活跃性,很多人在焊接中没有多去注意这些事情,一般这种辅料只存半年多,所以,很多行业在生产中,需要注意什么要点,这个希望大家都能清楚,下面佳金源锡膏厂家说明一下:一般在注意的情况
    的头像 发表于 01-25 14:57 421次阅读
    无铅锡膏印刷时<b class='flag-5'>需要注意</b>哪些技术<b class='flag-5'>要点</b>?

    使用安全光幕有哪些需要注意的吗?

    使用安全光幕有哪些需要注意的吗?
    的头像 发表于 06-29 09:38 378次阅读
    使用安全光幕有哪些<b class='flag-5'>需要注意</b>的吗?

    螺杆支撑座在使用中需要注意的事项

    螺杆支撑座在使用中需要注意的事项
    的头像 发表于 04-10 17:59 489次阅读
    螺杆支撑座在使用中<b class='flag-5'>需要注意</b>的事项

    设计软板pcb需要注意哪些事项

    设计软板pcb需要注意哪些事项
    的头像 发表于 12-19 10:06 240次阅读