快上网专注成都网站设计 成都网站制作 成都网站建设
成都网站建设公司服务热线:028-86922220

网站建设知识

十年网站开发经验 + 多家企业客户 + 靠谱的建站团队

量身定制 + 运营维护+专业推广+无忧售后,网站问题一站解决

pytorch中如何使用迁移学习resnet18训练mnist数据集

pytorch中如何使用迁移学习resnet18训练mnist数据集,相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。

站在用户的角度思考问题,与客户深入沟通,找到叙州网站设计与叙州网站推广的解决方案,凭借多年的经验,让设计与互联网技术结合,创造个性化、用户体验好的作品,建站类型包括:做网站、成都网站设计、企业官网、英文网站、手机端网站、网站推广、域名注册、网页空间、企业邮箱。业务覆盖叙州地区。

预备知识

  • 自己搭建cnn模型训练mnist(不使用迁移学习)

https://blog.csdn.net/qq_42951560/article/details/109565625

  • pytorch官方的迁移学习教程(蚂蚁、蜜蜂分类)

https://blog.csdn.net/qq_42951560/article/details/109950786

学习目标

今天我们尝试在pytorch中使用迁移学习来训练mnist数据集。

如何迁移

预训练模型

迁移学习需要选择一个预训练模型,我们这个任务也不是特别大,选择resnet18就行了。

数据预处理

resnet18输入的CHW(3, 224, 224)

mnist数据集中单张图片CHW(1, 28, 28)

所以我们需要对mnist数据集做一下预处理:
pytorch中如何使用迁移学习resnet18训练mnist数据集

# 预处理my_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.Grayscale(3),transforms.ToTensor(),transforms.Normalize((0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081)),])# 训练集train_file = datasets.MNIST(root='./dataset/',train=True,transform=my_transform)# 测试集test_file = datasets.MNIST(root='./dataset/',train=False,transform=my_transform)

pytorch中数据增强和图像处理的教程(torchvision.transforms)可以看我的这篇文章

改全连接层

resnet18是在imagenet上训练的,输出特征数是1000;而对于mnist来说,需要分10类,因此要改一下全连接层的输出。

model = models.resnet18(pretrained=True)in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10)

调整学习率

之前设置的Adam的学习率是1e-3,现在使用了迁移学习,所以学习率调小一点,改为1e-4

训练结果

resnet18相较于普通的一两层卷积网络来说已经比较深了,并且mnsit数据集还是挺大的,总共有7万张图片。为了节省时间,我们使用7张GeForce GTX 1080 Ti来训练:

  • 数据并行(DataParallel)

EPOCH: 01/10 STEP: 67/67 LOSS: 0.0266 ACC: 0.9940 VAL-LOSS: 0.0246 VAL-ACC: 0.9938 TOTAL-TIME: 102EPOCH: 02/10 STEP: 67/67 LOSS: 0.0141 ACC: 0.9973 VAL-LOSS: 0.0177 VAL-ACC: 0.9948 TOTAL-TIME: 80EPOCH: 03/10 STEP: 67/67 LOSS: 0.0067 ACC: 0.9990 VAL-LOSS: 0.0147 VAL-ACC: 0.9958 TOTAL-TIME: 80EPOCH: 04/10 STEP: 67/67 LOSS: 0.0042 ACC: 0.9995 VAL-LOSS: 0.0151 VAL-ACC: 0.9948 TOTAL-TIME: 80EPOCH: 05/10 STEP: 67/67 LOSS: 0.0029 ACC: 0.9997 VAL-LOSS: 0.0143 VAL-ACC: 0.9955 TOTAL-TIME: 80EPOCH: 06/10 STEP: 67/67 LOSS: 0.0019 ACC: 0.9999 VAL-LOSS: 0.0133 VAL-ACC: 0.9962 TOTAL-TIME: 80EPOCH: 07/10 STEP: 67/67 LOSS: 0.0013 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963 TOTAL-TIME: 80EPOCH: 08/10 STEP: 67/67 LOSS: 0.0008 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963 TOTAL-TIME: 79EPOCH: 09/10 STEP: 67/67 LOSS: 0.0006 ACC: 1.0000 VAL-LOSS: 0.0122 VAL-ACC: 0.9962 TOTAL-TIME: 79EPOCH: 10/10 STEP: 67/67 LOSS: 0.0005 ACC: 1.0000 VAL-LOSS: 0.0131 VAL-ACC: 0.9959 TOTAL-TIME: 79| BEST-MODEL | EPOCH: 07/10 STEP: 67/67 LOSS: 0.0013 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963

训练10轮,最佳的模型出现在第7轮,最大准确率是0.9963。在这篇文章中,我们自己搭了两层的卷积,也训练了10轮,最大准确率是0.9923。准确率提高了0.0040,我们要知道测试集共有1万张图片,也就是多预测对了40张图片,已经提升很高。当然,因为网络变深了,所以训练花费的时间也就增加了。

看完上述内容,你们掌握pytorch中如何使用迁移学习resnet18训练mnist数据集的方法了吗?如果还想学到更多技能或想了解更多相关内容,欢迎关注创新互联行业资讯频道,感谢各位的阅读!


当前标题:pytorch中如何使用迁移学习resnet18训练mnist数据集
浏览路径:http://6mz.cn/article/ijddhe.html

其他资讯