tnblog
首页
视频
资源
登录

Pytorch 卷积神经网络效果

1909人阅读 2023/12/29 15:06 总访问:3475666 评论:0 收藏:0 手机
分类: pytorch

Pytorch 卷积神经网络效果

数据集Dataloader制作

如何自定义数据集:


1.数据和标签的目录结构先搞定(得知道到哪读数据)
2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
3.完成单个数据与标签读取函数(给dataloader举一个例子)

咱们以花朵数据集为例:


原来数据集都是以文件夹为类别ID,现在咱们换一个套路,用txt文件指定数据路径与标签(实际情况基本都这样)
这回咱们的任务就是在txt文件中获取图像路径与标签,然后把他们交给dataloader
核心代码非常简单,按照对应格式传递需要的数据和标签就可以啦

  1. import os
  2. import matplotlib.pyplot as plt
  3. %matplotlib inline
  4. import numpy as np
  5. import torch
  6. from torch import nn
  7. import torch.optim as optim
  8. import torchvision
  9. #pip install torchvision
  10. from torchvision import transforms, models, datasets
  11. #https://pytorch.org/docs/stable/torchvision/index.html
  12. import imageio
  13. import time
  14. import warnings
  15. import random
  16. import sys
  17. import copy
  18. import json
  19. from PIL import Image

先来分细节整明白咱一会要干啥!

任务1:读取txt文件中的路径和标签


第一个小任务,从标注文件中读取数据和标签
至于你准备存成什么格式,都可以的,一会能取出来东西就行

  1. def load_annotations(ann_file):
  2. """
  3. 加载数据标签的方法
  4. """
  5. # 定义一个集合
  6. data_infos = {}
  7. with open(ann_file) as f:
  8. # 数据格式:image_06739.jpg 0
  9. # 读取每一行,strip方法去掉\r\n,split进行空格分割字符串
  10. samples = [x.strip().split(' ') for x in f.readlines()]
  11. # 存入集合中去
  12. for filename, gt_label in samples:
  13. # 使用np.array进行数据处理
  14. data_infos[filename] = np.array(gt_label, dtype=np.int64)
  15. return data_infos
  1. print(load_annotations('./train.txt'))

任务2:分别把数据和标签都存在list里


不是我非让你存list里,因为dataloader到时候会在这里取数据
按照人家要求来,不要耍个性,让整list咱就给人家整

  1. img_label = load_annotations('./train.txt')
  1. # 将图片名字转成list格式
  2. image_name = list(img_label.keys())
  3. # 将图片标签转成list格式
  4. label = list(img_label.values())
  1. image_name

  1. label

任务3:图像数据路径得完整


因为一会咱得用这个路径去读数据,所以路径得加上前缀
以后大家任务不同,数据不同,怎么加你看着来就行,反正得能读到图像

  1. data_dir = './'
  2. train_dir = data_dir + 'train_filelist'
  3. valid_dir = data_dir + 'val_filelist'
  1. # 路径拼接 join 第一个参数合并路径的前缀,通过遍历图片集进行拼接。
  2. image_path = [os.path.join(train_dir,img) for img in image_name]
  3. image_path

任务4:把上面那几个事得写在一起


1.注意要使用from torch.utils.data import Dataset, DataLoader
2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
3.def init(self, rootdir, annfile, transform=None):咱们要根据自己任务重写
4.def __getitem
(self, idx):根据自己任务,返回图像数据和标签数据

  1. from torch.utils.data import Dataset, DataLoader
  2. class FlowerDataset(Dataset):
  3. def __init__(self, root_dir, ann_file, transform=None):
  4. # 获取标签文件路径
  5. self.ann_file = ann_file
  6. # 获取当前数据路径
  7. self.root_dir = root_dir
  8. # 获取数据与标签
  9. self.img_label = self.load_annotations()
  10. # 获取相对图片路径
  11. self.img = [os.path.join(self.root_dir,img) for img in list(self.img_label.keys())]
  12. # 获取标签
  13. self.label = [label for label in list(self.img_label.values())]
  14. # 预处理
  15. self.transform = transform
  16. def __len__(self):
  17. return len(self.img)
  18. def __getitem__(self, idx):
  19. """
  20. idx是随机id,根据idx获取数据和标签
  21. """
  22. # 读取图片文件
  23. image = Image.open(self.img[idx])
  24. # 读取标签
  25. label = self.label[idx]
  26. # 处理图像数据
  27. if self.transform:
  28. image = self.transform(image)
  29. # 将numpy格式转成tensor对格式
  30. label = torch.from_numpy(np.array(label))
  31. # 返回图像数据和标签
  32. return image, label
  33. def load_annotations(self):
  34. """
  35. 加载数据标签的方法
  36. """
  37. data_infos = {}
  38. with open(self.ann_file) as f:
  39. samples = [x.strip().split(' ') for x in f.readlines()]
  40. for filename, gt_label in samples:
  41. data_infos[filename] = np.array(gt_label, dtype=np.int64)
  42. return data_infos

任务5:数据预处理(transform)


1.预处理的事都在上面的getitem中完成,需要对图像和标签咋咋地的,要整啥事,都在上面整
2.返回的数据和标签就是建模时模型的输入和损失函数中标签的输入,一定整明白自己模型要啥
3.预处理这个事是你定的,不同的数据需要的方法也不一样,下面给出的是比较通用的方法

  1. data_transforms = {
  2. 'train':
  3. transforms.Compose([
  4. transforms.Resize(64),
  5. transforms.RandomRotation(45),#随机旋转,-4545度之间随机选
  6. transforms.CenterCrop(64),#从中心开始裁剪
  7. transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
  8. transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
  9. transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
  10. transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
  11. transforms.ToTensor(),
  12. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
  13. ]),
  14. 'valid':
  15. transforms.Compose([
  16. transforms.Resize(64),
  17. transforms.CenterCrop(64),
  18. transforms.ToTensor(),
  19. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  20. ]),
  21. }

任务6:根据写好的class FlowerDataset(Dataset):来实例化咱们的dataloader


1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)
2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
3.打印看看数据里面是不是有东西了

  1. # 训练集 图片数据、数据标签、预处理(数据增强)
  2. train_dataset = FlowerDataset(root_dir=train_dir, ann_file = './train.txt', transform=data_transforms['train'])
  1. # 验证集
  2. val_dataset = FlowerDataset(root_dir=valid_dir, ann_file = './val.txt', transform=data_transforms['valid'])
  1. # 创建DataLoader 训练集和验证集,每次压入cpu的byte,是否随机
  2. train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  3. val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
  1. len(train_dataset)

6552

  1. len(val_dataset)

818

任务7:用之前先试试,整个数据和标签对应下,看看对不对


1.别着急往模型里传,对不对都不知道呢
2.用这个方法:iter(train_loader).next()来试试,得到的数据和标签是啥
3.看不出来就把图画出来,标签打印出来,确保自己整的数据集没啥问题

  1. # iter迭代一次 next取一个byte数据
  2. image, label = next(iter(train_loader))
  3. image.shape

torch.Size([64, 3, 64, 64])


第一个64是byte
最后两个64是长和宽

  1. image, label = next(iter(train_loader))
  2. # 取其中一个数据squeeze压缩一个维度,举例:1*3*64*64 压缩后 3*64*64
  3. sample = image[0].squeeze()
  4. # 根据索引替换位置,换成numpy格式
  5. sample = sample.permute((1, 2, 0)).numpy()
  6. # 还原图,因为当时做了均值差 transforms.Normalize
  7. sample *= [0.229, 0.224, 0.225]
  8. sample += [0.485, 0.456, 0.406]
  9. # 展示
  10. plt.imshow(sample)
  11. plt.show()
  12. print('Label is: {}'.format(label[0].numpy()))

  1. # 验证集
  2. image, label = next(iter(val_loader))
  3. sample = image[0].squeeze()
  4. sample = sample.permute((1, 2, 0)).numpy()
  5. sample *= [0.229, 0.224, 0.225]
  6. sample += [0.485, 0.456, 0.406]
  7. plt.imshow(sample)
  8. plt.show()
  9. print('Label is: {}'.format(label[0].numpy()))

任务8:咋用就是你来定了,把模型啥的整好往里面传吧


下面这些事之前都唠过了,按照自己习惯的方法整就得了

  1. # 放入dataloader
  2. dataloaders = {'train':train_loader,'valid':val_loader}
  1. model_name = 'resnet' #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
  2. #是否用人家训练好的特征来做
  3. feature_extract = True
  1. # 是否用GPU训练
  2. train_on_gpu = torch.cuda.is_available()
  3. if not train_on_gpu:
  4. print('CUDA is not available. Training on CPU ...')
  5. else:
  6. print('CUDA is available! Training on GPU ...')
  7. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


CUDA is not available. Training on CPU …

  1. model_ft = models.resnet18()
  2. model_ft

  1. num_ftrs = model_ft.fc.in_features
  2. model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
  3. input_size = 64
  4. model_ft

  1. # 优化器设置
  2. optimizer_ft = optim.Adam(model_ft.parameters(), lr=1e-3)
  3. scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)#学习率每7epoch衰减成原来的1/10
  4. criterion = nn.CrossEntropyLoss()
  1. def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, filename='best.pth'):
  2. since = time.time()
  3. best_acc = 0
  4. model.to(device)
  5. val_acc_history = []
  6. train_acc_history = []
  7. train_losses = []
  8. valid_losses = []
  9. LRs = [optimizer.param_groups[0]['lr']]
  10. best_model_wts = copy.deepcopy(model.state_dict())
  11. for epoch in range(num_epochs):
  12. print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  13. print('-' * 10)
  14. # 训练和验证
  15. for phase in ['train', 'valid']:
  16. if phase == 'train':
  17. model.train() # 训练
  18. else:
  19. model.eval() # 验证
  20. running_loss = 0.0
  21. running_corrects = 0
  22. # 把数据都取个遍
  23. for inputs, labels in dataloaders[phase]:
  24. inputs = inputs.to(device)
  25. labels = labels.to(device)
  26. # 清零
  27. optimizer.zero_grad()
  28. # 只有训练的时候计算和更新梯度
  29. with torch.set_grad_enabled(phase == 'train'):
  30. outputs = model(inputs)
  31. loss = criterion(outputs, labels)
  32. _, preds = torch.max(outputs, 1)
  33. #print(loss)
  34. # 训练阶段更新权重
  35. if phase == 'train':
  36. loss.backward()
  37. optimizer.step()
  38. # 计算损失
  39. running_loss += loss.item() * inputs.size(0)
  40. running_corrects += torch.sum(preds == labels.data)
  41. epoch_loss = running_loss / len(dataloaders[phase].dataset)
  42. epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
  43. time_elapsed = time.time() - since
  44. print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  45. print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
  46. # 得到最好那次的模型
  47. if phase == 'valid' and epoch_acc > best_acc:
  48. best_acc = epoch_acc
  49. best_model_wts = copy.deepcopy(model.state_dict())
  50. state = {
  51. 'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
  52. 'best_acc': best_acc,
  53. 'optimizer' : optimizer.state_dict(),#优化器的状态信息
  54. }
  55. torch.save(state, filename)
  56. if phase == 'valid':
  57. val_acc_history.append(epoch_acc)
  58. valid_losses.append(epoch_loss)
  59. scheduler.step(epoch_loss)#学习率衰减
  60. if phase == 'train':
  61. train_acc_history.append(epoch_acc)
  62. train_losses.append(epoch_loss)
  63. print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
  64. LRs.append(optimizer.param_groups[0]['lr'])
  65. print()
  66. time_elapsed = time.time() - since
  67. print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  68. print('Best val Acc: {:4f}'.format(best_acc))
  69. # 训练完后用最好的一次当做模型最终的结果,等着一会测试
  70. model.load_state_dict(best_model_wts)
  71. return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
  1. model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, filename='best.pth')


欢迎加群讨论技术,1群:677373950(满了,可以加,但通过不了),2群:656732739

评价

Pytorch 初探

Pytorch 初探[TOC] Pytorch简介PyTorch由 Facebook 的 AI 研究团队开发的一个开源的机器学习库,它提供了张量(tensor)计...

Pytorch 自动求导与简单的线性回归

Pytorch 自动求导与简单的线性回归[TOC] 环境安装安装pytorch%pip install torch torchvision torchaudio 自动计算反向...

Pytorch Tensor 常见的形式

Pytorch Tensor 常见的形式[TOC] 值 描述 scalar 0维张量 vector 1维张量 matrix 2维张量 ...

Pytorch 气温预测

Pytorch 气温预测[TOC] 准备数据集下载temps.csv数据集。# 下载包 !wget https://raw.githubusercontent.com/AiDaShi/lea...

Pytorch Mnist分类任务

Pytorch Mnist分类任务[TOC] Mnist分类任务了解目标——网络基本构建与训练方法,常用函数解析——torch.nn.functional...

Pytorch 卷积神经网络效果

Pytorch 卷积神经网络效果[TOC] 数据与在线实践数据链接: https://pan.baidu.com/s/1VkrHDZGukkF900zLncMn5g 密码: 3lom...

Pytorch 基于经典网络架构训练图像分类模型

Pytorch 基于经典网络架构训练图像分类模型[TOC] 数据预处理部分:数据增强:torchvision中transforms模块自带功能,比较...

Pytorch 新闻分类任务(学习笔记)

Pytorch 新闻分类任务(学习笔记)[TOC] 目录结构 models文件夹该文件夹显示搭建的网络结构。里面有TextCNN.py和TextRNN....

Pytorch Flask服务部署图片识别(学习笔记)

Pytorch Flask服务部署图片识别(学习笔记)[TOC] Flask 简介Flask是一个用Python编写的轻量级Web应用框架。它简单易用,...

Pytorch 预测产量(易化学习笔记一)

Pytorch 预测产量(易化学习笔记一)[TOC] 实验目的(二维)通过温度进行产量预测。 实验代码导入数据集import torch im...

Pytorch 曲线拟合(易化学习笔记二)

Pytorch 曲线拟合(易化学习笔记二)[TOC] 感染与天数预测import matplotlib.pyplot as plt import torch import torch....

Pytorch 识别手写数字(易化学习笔记三)

Pytorch 识别手写数字(易化学习笔记三)[TOC] 识别手写数字LeNet-5手写数字识别的非常高效的卷积神经网络。高效原因:1....

Pytorch cifar10识别普适物体(易化学习笔记四)

Pytorch cifar10识别普适物体(易化学习笔记四)[TOC] CIFAR-10简介CIFAR-10(Canadian Institute For Advanced Research...

Pytorch loguru日志收集(易化学习笔记五)

Pytorch loguru日志收集(易化学习笔记五)[TOC] loguru日志简介Loguru 是一个 Python 日志库,旨在简化日志记录的设置和...

Pytorch TensorBoard运用(易化学习笔记六)

Pytorch TensorBoard运用(易化学习笔记六)[TOC] TensorBoard简介TensorBoard是TensorFlow的可视化工具包,旨在帮助研究...
这一世以无限游戏为使命!
排名
2
文章
636
粉丝
44
评论
93
docker中Sware集群与service
尘叶心繁 : 想学呀!我教你呀
一个bug让程序员走上法庭 索赔金额达400亿日元
叼着奶瓶逛酒吧 : 所以说做程序员也要懂点法律知识
.net core 塑形资源
剑轩 : 收藏收藏
映射AutoMapper
剑轩 : 好是好,这个对效率影响大不大哇,效率高不高
ASP.NET Core 服务注册生命周期
剑轩 : http://www.tnblog.net/aojiancc2/article/details/167
ICP备案 :渝ICP备18016597号-1
网站信息:2018-2025TNBLOG.NET
技术交流:群号656732739
联系我们:contact@tnblog.net
公网安备:50010702506256
欢迎加群交流技术