tnblog
首页
视频
资源
登录

python 卷积图像识别实战

6381人阅读 2023/5/7 17:32 总访问:3455086 评论:1 收藏:0 手机
分类: AI

python 卷积图像识别实战

前提回顾


我们通过卷积将原始图像,通过卷积核,得出卷积后的图像,这个中间的过程被称为卷积层,最后纳入神经网络中进行训练。
当然你可能有这样的疑惑,就是如何得出卷积核的值呢?


卷积核的值是通过训练而来的。
还是与以前一样通过反向传播、梯度下降得来的,区别在于提取数据的方式不同。

卷积核值原理


假设我们的原始图的大小只有4*4,用一个3*3的卷积核去卷它,得到一个2*2的矩阵,卷完以后我们送入到一个全连接层中。
然后通过代价函数,把误差代价传播到这两层的各个权值和偏置项中,但如何传到这个卷积层中呢?貌似就不行了。


我们首先通过一个随机的3*3的卷积核,在4*4的原始图中得出2*2的卷积图像。然后把这四个值当成一个四个神经元,并加上偏置项b中,并再带上Rule函数就与神经网络没什么两样了。


但是这里面的细节还是有些不一样,首先这四个神经元的输出是根据卷积的过程排列而成的二位结构,而不是这样的平铺,到后面我们送到全链接层的时候我们需要手动把它平铺开;然后这四个神经元的输入并不相同,实际上是同一个图片的不同位置,并且使用同一个卷积核,所以他们的偏置项是一样的(这也是所谓的参数共享)。


一个卷积核提取出来的结果是提取图像的一种特征,我们需要提取图像更多的特征。
实际上你想要多少个特征你就搞多少个卷积核就好了。


而这三个卷积核得出的结果是一个三维向量,我们可以将三维的向量铺开就是一个一维的向量,然后在后面构造全链接层。


但在我们卷全连接之前我们还可不可以继续卷这个6*6*3的数据呢?当然可以比如LeNet-5网络。

LeNet-5神经网络结构


LeNet-5网络中就卷了两次之后再送入全连接层,那我们怎么去卷这个6*6*3的三维数据呢?实际上卷积运算不仅可以在二维上进行,同样可以在3维数据上进行,我们在3维数据上我们用3维的卷积核。


3维卷积的方式和二维几乎一模一样,三维卷积先找到数据和卷积核立方块对应的位置的元素相乘,然后加起来得到一个结果,当然最后还需要给结果加上偏置项b,再经过激活做非线性运算得到最后的输出。


当然我们也可以添加4个神经元变成4*4*4的张量数据,被卷积核的数据的第3个维度也就是所谓的通道数(6*6*3)。为什么是通道数?


如果我们的图像是彩色的RGP三个通道,那我们的输入是一个3通道的图像,所以把数据的第三个维度值称为通道数量,要对这样的值进行卷积那么卷积核也必须是三维的,而且第三个维度值也必须和数据的通道数一样。


如果我们还想把第二层卷积的输出结果继续卷,卷积核的第三个维度值需要和数据的通道数一样,我们可以使用3*3*4的卷积核。


除了我们两个卷积层以外还多出了两个立方块,这两个立方块就是所谓的池化层。


我们从输出的左上角开始框出2*2(举例不一定非得2*2),再求出这个2*2数据的平均值得到第二个结果,然后一直顶到最右边。这个操作就是池化


还有一种常见的方法是取最大值,就是取4个数的最大值,就叫最大池化,前面采用平均值的方式也叫AveragePooling平均池化。

图片卷积变大变小公式


变小公式,举例把32*32的图片像素转成28*28图片像素需要一个5*5的卷积核。
填充完了之后信息损失越来越多,之后人们又提出了新的填充方案。


比如卷积核是5*5就填充两圈,这样原始图片就从28*28填充成了32*32,这种卷积模式就是Same模式


除了Same模式还有一种valid模式,越卷越小。

编程实践


模拟LeNet-5神经网络的实现。

  1. from keras.datasets import mnist
  2. import numpy as np
  3. from keras.models import Sequential
  4. from keras.layers import Dense
  5. from keras.optimizers import SGD
  6. import plot_utils_2
  7. import matplotlib.pyplot as plt
  8. from keras.utils.np_utils import to_categorical
  9. # 导入二维卷积层
  10. from keras.layers import Conv2D
  11. # 导入池化层
  12. from keras.layers import AveragePooling2D
  13. # 平铺数组
  14. from keras.layers import Flatten
  15. (X_train, Y_train), (X_test, Y_test) = mnist.load_data()
  16. # ndarray 的 reshape函数改变数组的形状。
  17. # 把28*28的图片像素。
  18. # 除以255的目的在于降低梯度下降的复杂度
  19. # 训练集
  20. X_train = X_train.reshape(60000,28,28,1)/255
  21. # 测试集
  22. X_test = X_test.reshape(10000,28,28,1)/255
  23. # 转换成One Hot编码
  24. Y_train = to_categorical(Y_train,10)
  25. Y_test = to_categorical(Y_test,10)
  26. model = Sequential()
  27. # filter 卷积核/过滤器数量
  28. # kernel_size 卷积核尺寸
  29. # strides 步长(一般都是挪动一步,当然也可以挪动多步)
  30. # input_shape 输入形状(这里是一个28*28*1的张量)
  31. # padding 填充模式 valid模式越小
  32. # activation 激活函数 relu
  33. #
  34. # Convolutions
  35. model.add(Conv2D(filter=6,kernel_size=(5,5),strides=(1,1),input_shape=(28,28,1),padding='valid',activation='relu'))
  36. # 池化大小和窗口大小。
  37. # 注意:keras中池化操作的步长不指定默认和pool_size一样,这里的步长也就默认是(2,2)
  38. #
  39. # Subsamping
  40. model.add(AveragePooling2D(pool_size=(2,2)))
  41. # 卷积出8*8*16
  42. #
  43. # Convolutions
  44. model.add(Conv2D(filter=16,kernel_size=(5,5),strides=(1,1),padding='valid',activation='relu'))
  45. # 输出4*4*16
  46. #
  47. # Subsamping
  48. model.add(AveragePooling2D(pool_size=(2,2)))
  49. # Full connection 平铺
  50. model.add(Flatten())
  51. # 输出层
  52. model.add(Dense(units=120, activation='relu'))
  53. model.add(Dense(units=84, activation='relu'))
  54. model.add(Dense(units=10, activation='softmax'))
  55. # 送入训练
  56. model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.05),metrics=['accuracy'])
  57. model.fit(X_train,Y_train,epochs=5000,batch_size=4096)
  58. # 评估测试集
  59. # 添加测试数据进行测试模型的泛化能力
  60. loss, accuracy = model.evaluate(X_test, Y_test)
  61. # 打印出损失和准确率
  62. print("loss"+str(loss))
  63. print("loss"+str(accuracy))


欢迎加群讨论技术,1群:677373950(满了,可以加,但通过不了),2群:656732739

评价

瑾语

2023/5/11 9:12:52

111

python学习 1-安装

Ptyhon非常简单易用的面向对象的脚本语言,跨平台 入门简单python分2个版本 Python2、Python3。Python 2.7 将于 2020 年结...

python学习 2-基本语法

基础:python脚本语言,不需要编译(像C#、Java、PHP、C++需要编译成机器可识别的语言), 而直接由解释器解释,很多地方类似...

python学习 3-爬虫基本介绍 及简单实例

爬虫爬虫就是一只猪,蜘蛛。。 网络蜘蛛。互联网是一个网由各个网站组成。无数的蜘蛛就在网上到处爬,根据网址从一个网站爬...

python实例 1-日志抓取处理 补错(附日志小技巧)

有时候数据出了问题,可以从日志中恢复数据(如果你没记日志..没备份..→_→..)一、日志展示介绍个平常自己用的小方法,如...

python实例 2-12306抢票(一) 登陆

开坑年关将近,终于对12306下手了,,平安夜撸代码,攻克了12306的登陆 2018-12-24 22:16:00没错 这篇博客就写从零开始的异...

python安装pip以及使用pip安装requests等模块

pip很简单的介绍pip 是一个现代的,通用的 Python 包管理工具。提供了对 Python 包的查找、下载、安装、卸载的功能。如果想...

python数据集合区别

列表(list):可读写,值可以重复,有序排列,初始化语法:['tom',‘jerry’]元组(tuple):只读,值可以重复,...

python实例 2-12306抢票(二) 下单

第二篇 刷票与下单1.记住登陆上一篇写了登陆:http://www.tnblog.net/cz/article/details/162 为了方便调试 不让每次登陆都...

使用VS Code开发python

Vs Code开发Python可以很好的支持代码调试、智能提示、代码对齐等1:下载VS Codehttps://code.visualstudio.com/Downloadvs ...

python变量与命名

Python使用等号 ( = ) 作为赋值运算符,例如a = 66 就是一条赋值语句,作用就是将66赋值给变量a。Python是弱类型语言,弱类...

python关键字和内置函数

Python中包含了如下所示的关键字:上面这些关键字都不能作为变量名。另外,Python 3还提供了如下表所示的内置函数:也不能...

python基础输入和输出

Python使用print()函数向打印程序输出,采用input()函数接收程序输入。print()函数的语法格式如下:print(value,...,sep=&#...

python基本数据类型

Python包含两大类基本数据类型:数值类型、字符串类型,三大类组合数据类型:集合类型、序列类型和字典类型.数值类型:整型...

python中通过fake_useragent生成随机UserAgent

fake_useragent第三方库,来实现随机请求头的设置;GitHub ---> https://github.com/hellysmile/fake-useragent安...

python 升级pip

一条命令即可: python -m pip install --upgrade pip 安装成功后,一般是在python目录下面的Scripts里边的

python html编码解码

使用方法:html.escape(s)与html.unescape即可 import html s="<div>jsdlfjsl</div>" #html编...
这一世以无限游戏为使命!
排名
2
文章
633
粉丝
44
评论
93
docker中Sware集群与service
尘叶心繁 : 想学呀!我教你呀
一个bug让程序员走上法庭 索赔金额达400亿日元
叼着奶瓶逛酒吧 : 所以说做程序员也要懂点法律知识
.net core 塑形资源
剑轩 : 收藏收藏
映射AutoMapper
剑轩 : 好是好,这个对效率影响大不大哇,效率高不高
ASP.NET Core 服务注册生命周期
剑轩 : http://www.tnblog.net/aojiancc2/article/details/167
ICP备案 :渝ICP备18016597号-1
网站信息:2018-2025TNBLOG.NET
技术交流:群号656732739
联系我们:contact@tnblog.net
公网安备:50010702506256
欢迎加群交流技术