tnblog
首页
视频
资源
登录

Pytorch cifar10识别普适物体(易化学习笔记四)

1474人阅读 2024/6/10 13:47 总访问:3511524 评论:0 收藏:0 手机
分类: python

Pytorch cifar10识别普适物体(易化学习笔记四)

CIFAR-10简介


CIFAR-10(Canadian Institute For Advanced Research)是一个广泛用于机器学习和计算机视觉研究的标准数据集,主要用于图像识别任务。它由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton在2009年创建,包含10个不同类别的普适物体。每个类别有6000张32x32彩色图像,总共有60000张图像,其中50000张用于训练,10000张用于测试。这些类别包括飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。

代码实践

数据处理

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. # 转换的集合
  5. transform = transforms.Compose(
  6. [transforms.ToTensor(), # 归一化,转为[0,1.0] shape[C,H,W]的张量
  7. # 正则化:前(0.5, 0.5, 0.5)是RGB通道均值,后(0.5, 0.5,0.5)是RGB通道标准差 -> 减少泛化误差(防过拟合)
  8. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  9. # 每批取的图片数(这里取4张)
  10. batch_size = 4
  11. # 获取训练集
  12. trainset = torchvision.datasets.CIFAR10(
  13. root='./data', # 设置数据集的根目录
  14. train=True, # 是训练集
  15. download=True, # 如本地无,则从网络下载
  16. transform=transform) # 设置转换函数
  17. # 载入训练集
  18. trainloader = torch.utils.data.DataLoader(
  19. trainset, # 指定载入训练集
  20. batch_size=batch_size, # 每批取的数目
  21. shuffle=True, # 乱序打包
  22. num_workers=2) # 设置多线程数(加num_workers 有的设备可能报错 )
  23. # 获取测试集
  24. testset = torchvision.datasets.CIFAR10(root='./data', train=False,
  25. download=True, transform=transform)
  26. # 载入测试集
  27. testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
  28. shuffle=False, num_workers=2)
  29. classes = ('plane', 'car', 'bird', 'cat',
  30. 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


从训练集中查看一张图片。

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. # 显示图像
  4. def imshow(img):
  5. img = img / 2 + 0.5
  6. npimg = img.numpy()
  7. # 把(channel,height,weight) 转为Matplotlib能识别的 (height,weight,channel)
  8. show_img = np.transpose(npimg, (1, 2, 0))
  9. # (36, 138, 3) 合并后的图片高36宽138,3个通道(rgb)
  10. print(show_img.shape)
  11. plt.imshow(show_img)
  12. plt.show()
  13. # 随机获取一批图像样本
  14. dataiter = iter(trainloader)
  15. images, labels = next(dataiter)
  16. # torch.Size([4, 3, 32, 32]) 4张 3通道(RGB) 的32*32的图片
  17. print(images.shape)
  18. # tensor([6, 6, 3, 7]) 4张图片分别对应的类别字典的索引(位置)
  19. print(labels)
  20. # 拼成一幅图像显示: 把4维 (batch_size,channel,height,weight) 变为3维 (channel,height,weight)
  21. imshow(torchvision.utils.make_grid(images))
  22. print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))
  23. # 根据索引值查找类别名
  24. # 如 frog frog cat horse (图像对应的类别名)

构建模型

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. import torch.optim as optim
  4. # 定义网络模型
  5. class Net(nn.Module):
  6. def __init__(self):
  7. super().__init__()
  8. self.conv1 = nn.Conv2d(3, 6, 5) # 卷积1: 输入3个RGB通道32*32的图片,输出6张特征图, 用5x5的卷积核
  9. self.pool = nn.MaxPool2d(2, 2) # 最大值池化: 2*2的卷积核, 步长2
  10. self.conv2 = nn.Conv2d(6, 16, 5) # 卷积2: 输入6,输出16,5*5的卷积核
  11. self.fc1 = nn.Linear(16 * 5 * 5, 120) # 线性回归 (in:16*5*5 out:120)
  12. self.fc2 = nn.Linear(120, 84)
  13. self.fc3 = nn.Linear(84, 10)
  14. # 前向传播
  15. def forward(self, x):
  16. x = self.pool(F.relu(self.conv1(x))) # 构建卷积层1(Conv2d->relu->MaxPool2d)
  17. x = self.pool(F.relu(self.conv2(x))) # 构建卷积层2(Conv2d->relu->MaxPool2d)
  18. x = torch.flatten(x, 1) # 扁平化所有维度,除了batch外 -> 方便对接全连接层
  19. # 替代老式做法 x = x.view(-1, 16 * 5 * 5) 扁平化:行-1自动推导 列未16*5*5, 把16张 5*5的特征图压平为一维的点
  20. x = F.relu(self.fc1(x)) # 构建全连接层1( Linear->relu)
  21. x = F.relu(self.fc2(x))
  22. x = self.fc3(x)
  23. return x
  24. net = Net() # 新建模型
  25. criterion = nn.CrossEntropyLoss() # 损失函数用交叉熵
  26. optimizer = optim.Adam(net.parameters(), lr=0.001)

训练评估

  1. # 训练8轮(注
  2. for epoch in range(8):
  3. running_loss = 0.0
  4. for i, data in enumerate(trainloader, 0):
  5. inputs, labels = data # 获取 [inputs:输入图片 labels:图片对应的数字标签]
  6. optimizer.zero_grad() # 梯度清零: 每次迭代都需梯度清零,因pytorch默认会累积梯度
  7. # 前向传播
  8. outputs = net(inputs) # 用模型对输入进行预测
  9. loss = criterion(outputs, labels) # 计算损失(误差)
  10. # 反向传播
  11. loss.backward() # 反向传播,计算梯度
  12. optimizer.step() # 优化一步(梯度下降)
  13. running_loss += loss.item()
  14. if i % 2000 == 1999: # 每2000批,打印一次统计
  15. # 打印如 [2, 10000] loss: 1.274 第2轮训练 训练样本10000 损失为1.274
  16. print('[%d, %5d] loss: %.3f' %
  17. (epoch + 1, i + 1, running_loss / 2000))
  18. running_loss = 0.0
  19. print('Finished Training')


评估1:用1000个测试图片预测,正确率是多少

  1. correct = 0
  2. total = 0
  3. # 因为我们没有训练,所以我们不需要计算输出的梯度
  4. with torch.no_grad():
  5. for data in testloader:
  6. images, labels = data
  7. # 用模型对输入进行预测
  8. outputs = net(images)
  9. # 取概率最大做预测值(字典里的索引位置)
  10. _, predicted = torch.max(outputs.data, 1)
  11. # 预测总数
  12. total += labels.size(0)
  13. # 预测正确的次数
  14. correct += (predicted == labels).sum().item()
  15. # 输出1000个测试图片中,正确率(精度: 正确率= 正确的次数/总数)是多少
  16. print('Accuracy of the network on the 10000 test images: %d %%' % (
  17. 100 * correct / total))
  18. # 如 Accuracy of the network on the 10000 test images: 54 %


评估:对10个类别预测,那一类的正确率更高

  1. class_correct = list(0. for i in range(10))
  2. class_total = list(0. for i in range(10))
  3. with torch.no_grad():
  4. for data in testloader:
  5. images, labels = data
  6. outputs = net(images)
  7. _, predicted = torch.max(outputs, 1)
  8. c = (predicted == labels).squeeze()
  9. for i in range(4):
  10. label = labels[i]
  11. class_correct[label] += c[i].item()
  12. class_total[label] += 1
  13. for i in range(10):
  14. print('Accuracy of %5s : %2d %%' % (
  15. classes[i], 100 * class_correct[i] / class_total[i]))


保存模型

  1. PATH = './cifar_net.pth'
  2. torch.save(net.state_dict(), PATH) # 保存模型的参数

从测试集中随机取一批图像,看预测效果

  1. # 从测试集中随机取一批图像
  2. images, labels = next(iter(testloader))
  3. # 查看图像
  4. imshow(torchvision.utils.make_grid(images))
  5. # 查看标签
  6. print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
  7. # 新建模型,用于测试评估(必须与训练时的模型一致)
  8. net = Net()
  9. # 载入已训练好的模型参数
  10. net.load_state_dict(torch.load(PATH))
  11. # 预测:用模型识别图像
  12. outputs = net(images)
  13. # 取概率最大做预测值(字典里的索引位置)
  14. _, predicted = torch.max(outputs, 1)
  15. # 打印预测的图像的类别(如cat horse dog bird)
  16. print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))


欢迎加群讨论技术,1群:677373950(满了,可以加,但通过不了),2群:656732739

评价

C ?、?? 问号和2个问号的用法类型?、对象?

C# ?C# ???:单问号1.定义数据类型可为空。可用于对int,double,bool等无法直接赋值为null的数据类型进行null的赋值如这...

Python实例 1-日志抓取处理 补错附日志小技巧

有时候数据出了问题,可以从日志中恢复数据(如果你没记日志..没备份..→_→..)一、日志展示介绍个平常自己用的小方法,如...

C 数组拆分泛型

主要用到了泛型。泛型是c#2.0的一个新增加的特性,它为使用c#语言编写面向对象程序增加了极大的效力和灵活性。不会强行对值...

MySQL 视图的增删改 查

要显示视图的定义,需要在SHOWCREATEVIEW子句之后指定视图的名称, 我们先来创建几张表,完事后在进行演示:--用户信息表...

使用NPOI导出excel包括图片

Excl模板导出相信我们都会,那么模板上要导出图片呢?嗯~还是来个例子:准备工作:首先要引用NPOI包:然后获取数据集(我这...

ajaxSubmit异步上传图片嘘,外面都是假的

引用代码<scriptsrc="/Scripts/jquery.form.js"></script>js就在旁边img链接中,只不过大小为0x0,...

.NET MVC 使用百度编辑器详细教程:1配置编辑器

一、什么是百度编辑器百度编辑器UEditor是由百度web前端研发部开发一款应用于网站的编辑器,具有轻量,可定制,注重用户体...

使用jquery操作元素的css样式获取、修改等等

使用jquery操作元素的css样式(获取、修改等等) //1、获取和设置样式 $("#tow").attr("class")...

.net辗转java系列视野

.net辗转java系列(一)视野.net系java系其它语言C#Java框架.net Framework Standardjava se.net corejava eejave meJava S...

.NET MVC json对象或者json对象数组的序列化和反序列化

1、用JSON.stringify()将对象stuarr或者json数组stuarr序列化成字符串,然后提交给后台。$.post("/home/DoUpdate&quot...

.NET MVC json对象或者json对象数组的序列化和反序列化

1、用JSON.stringify()将对象stuarr或者json数组stuarr序列化成字符串,然后提交给后台。$.post("/home/DoUpdate&quot...

mui框架-移动端跳转以及传值的简单方法修改解决方法

纠结了两天的MUI跳转的问题,终于解决了 ,现在分享给大家,希望大家有什么坑的解决也给我分享分享 哈哈,废话不多说,上代...

MVC全局异常处理错误日记

1、在Filter文件夹中创建一个IsExceptionFilter类(类名随意取)2、使用3、在访问的页面控制器中添加几个错误4、在IsExcept...

MVC全局异常处理错误日记

1、在Filter文件夹中创建一个IsExceptionFilter类(类名随意取)2、使用3、在访问的页面控制器中添加几个错误4、在IsExcept...

Hbuilder打包APP的教程会操作的略过

首先项目必须是APP端的,可能讲解有点啰嗦,讲解准备的工具:HBuilderX(其他版本也可以,这里用X版本来讲解)、待测试手机...
这一世以无限游戏为使命!
排名
2
文章
642
粉丝
44
评论
93
docker中Sware集群与service
尘叶心繁 : 想学呀!我教你呀
一个bug让程序员走上法庭 索赔金额达400亿日元
叼着奶瓶逛酒吧 : 所以说做程序员也要懂点法律知识
.net core 塑形资源
剑轩 : 收藏收藏
映射AutoMapper
剑轩 : 好是好,这个对效率影响大不大哇,效率高不高
ASP.NET Core 服务注册生命周期
剑轩 : http://www.tnblog.net/aojiancc2/article/details/167
ICP备案 :渝ICP备18016597号-1
网站信息:2018-2025TNBLOG.NET
技术交流:群号656732739
联系我们:contact@tnblog.net
公网安备:50010702506256
欢迎加群交流技术