tnblog
首页
视频
资源
登录

Pytorch Flask服务部署图片识别(学习笔记)

2026人阅读 2024/1/5 14:16 总访问:3475668 评论:0 收藏:0 手机
分类: pytorch

Pytorch Flask服务部署图片识别(学习笔记)

Flask 简介


Flask是一个用Python编写的轻量级Web应用框架。
它简单易用,但同时也足够灵活和强大,能够支持复杂的Web应用。
由于其轻量级的特性,Flask非常适合用作在Web上部署机器学习模型的工具。

简单来讲:启动一个服务,根据传上来的东西进行预测并返回结果。

安装Flask

  1. python -m pip install flask

实践目录

文件或文件夹 描述
flower_data 训练的图像数据
best.pth 训练好的模型
flask_server.py 服务器端代码
flask_predict.py 客户端请求代码

服务器端


服务器对需要预处理的图片流程如下图所示:


flask_server.py代码如下所示:

  1. import io
  2. import json
  3. # flask 服务
  4. import flask
  5. import torch
  6. import torch
  7. import torch.nn.functional as F
  8. from PIL import Image
  9. from torch import nn
  10. #from torchvision import transforms as T
  11. from torchvision import transforms, models, datasets
  12. from torch.autograd import Variable
  13. # 初始化Flask app
  14. app = flask.Flask(__name__)
  15. model = None
  16. use_gpu = False
  17. # 加载模型进来
  18. def load_model():
  19. """Load the pre-trained model, you can use your model just as easily.
  20. """
  21. # 定义一个全局变量
  22. global model
  23. #这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
  24. model = models.resnet18()
  25. num_ftrs = model.fc.in_features
  26. model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 102类的分类任务
  27. #print(model) 加载模型
  28. checkpoint = torch.load('best.pth')
  29. # 加载权重参数
  30. model.load_state_dict(checkpoint['state_dict'])
  31. #将模型指定为测试格式
  32. model.eval()
  33. #是否使用gpu
  34. if use_gpu:
  35. model.cuda()
  36. # 数据预处理
  37. def prepare_image(image, target_size):
  38. """Do image preprocessing before prediction on any data.
  39. :param image: original image
  40. :param target_size: target image size
  41. :return:
  42. preprocessed image
  43. """
  44. #针对不同模型,image的格式不同,但需要统一至RGB格式
  45. if image.mode != 'RGB':
  46. image = image.convert("RGB")
  47. # Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
  48. # 图片与训练尺寸大小一致
  49. image = transforms.Resize(target_size)(image)
  50. # 转tensor格式
  51. image = transforms.ToTensor()(image)
  52. # Convert to Torch.Tensor and normalize. mean与std (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致
  53. # 设置均值和标准差
  54. image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
  55. # Add batch_size axis.增加一个维度,用于按batch测试 本次这里一次测试一张
  56. # 举例:1*3*64*64
  57. image = image[None]
  58. if use_gpu:
  59. image = image.cuda()
  60. return Variable(image, volatile=True) #不需要求导
  61. # 开启服务 这里的predict是API路径、使用POST请求
  62. @app.route("/predict", methods=["POST"])
  63. def predict():
  64. # Initialize the data dictionary that will be returned from the view.
  65. #做一个标志,刚开始无图像传入时为false,传入图像时为true
  66. data = {"success": False}
  67. # 如果收到请求
  68. if flask.request.method == 'POST':
  69. #判断是否为图像
  70. if flask.request.files.get("image"):
  71. # Read the image in PIL format
  72. # 将收到的图像进行读取
  73. image = flask.request.files["image"].read()
  74. image = Image.open(io.BytesIO(image)) #二进制数据
  75. # 利用上面的预处理函数将读入的图像进行预处理
  76. image = prepare_image(image, target_size=(64, 64))
  77. # 放入模型中进行预测,softmax得到各个类别的概率
  78. preds = F.softmax(model(image), dim=1)
  79. # k找出类别前3高的
  80. results = torch.topk(preds.cpu().data, k=3, dim=1)
  81. # 结果转成cpu最后转成numpy
  82. results = (results[0].cpu().numpy(), results[1].cpu().numpy())
  83. #将data字典增加一个key,value,其中value为list格式
  84. data['predictions'] = list()
  85. # 遍历每一个预测结果
  86. for prob, label in zip(results[0][0], results[1][0]):
  87. #label_name = idx2label[str(label)]
  88. # label真实值,和probability概率值
  89. r = {"label": str(label), "probability": float(prob)}
  90. # 将预测结果添加至data字典
  91. data['predictions'].append(r)
  92. # Indicate that the request was a success.
  93. data["success"] = True
  94. # 将最终结果以json格式文件传出
  95. return flask.jsonify(data)
  96. """
  97. test_json = {
  98. "status_code": 200,
  99. "success": {
  100. "message": "image uploaded",
  101. "code": 200
  102. },
  103. "video":{
  104. "video_name":opt['source'].split('/')[-1],
  105. "video_path":opt['source'],
  106. "description":"1",
  107. "length": str(hour)+','+str(minute)+','+str(round(second,4)),
  108. "model_object_completed":model_flag
  109. }
  110. "status_txt": "OK"
  111. }
  112. response = requests.post(
  113. 'http://xxx.xxx.xxx.xxx:8090/api/ObjectToKafka/',,
  114. data={'json': str(test_json)})
  115. """
  116. if __name__ == '__main__':
  117. print("Loading PyTorch model and Flask starting server ...")
  118. print("Please wait until server has fully started")
  119. #先加载模型
  120. load_model()
  121. #再开启服务
  122. app.run(port='5012')


这里我开放的端口是5012,通过请求/predict链接,通过执行如下命令将程序跑起来:

  1. python flask_server.py

只要把Flask关了模型就没了,如果Flask一直开着的模型就一直都在跑。

客户端


客户端主要是上传一张image_06998.jpg的图片到服务器中去预测,代码如下:

  1. import requests
  2. import argparse
  3. # url和端口携程自己的
  4. flask_url = 'http://127.0.0.1:5012/predict'
  5. def predict_result(image_path):
  6. #传入本地图片
  7. image = open(image_path, 'rb').read()
  8. payload = {'image': image}
  9. #request发给server.
  10. r = requests.post(flask_url, files=payload).json()
  11. # 成功的话在返回.
  12. if r['success']:
  13. # 输出结果.
  14. for (i, result) in enumerate(r['predictions']):
  15. print('{}. {}: {:.4f}'.format(i + 1, result['label'],
  16. result['probability']))
  17. # 失败了就打印.
  18. else:
  19. print('Request failed')
  20. if __name__ == '__main__':
  21. parser = argparse.ArgumentParser(description='Classification demo')
  22. # 添加参数
  23. parser.add_argument('--file', default='./flower_data/train_filelist/image_06998.jpg', type=str, help='test image file')
  24. args = parser.parse_args()
  25. # 开始请求
  26. predict_result(args.file)
  1. python flask_predict.py


预测结果如下所示:


我们可以看到预测得最相似的label是34,准确率97%,我们去图片数据中找找这张图片的训练集验证一下。


训练的结果与预期的结果一致。


欢迎加群讨论技术,1群:677373950(满了,可以加,但通过不了),2群:656732739

评价

C ?、?? 问号和2个问号的用法类型?、对象?

C# ?C# ???:单问号1.定义数据类型可为空。可用于对int,double,bool等无法直接赋值为null的数据类型进行null的赋值如这...

Python实例 1-日志抓取处理 补错附日志小技巧

有时候数据出了问题,可以从日志中恢复数据(如果你没记日志..没备份..→_→..)一、日志展示介绍个平常自己用的小方法,如...

C 数组拆分泛型

主要用到了泛型。泛型是c#2.0的一个新增加的特性,它为使用c#语言编写面向对象程序增加了极大的效力和灵活性。不会强行对值...

MySQL 视图的增删改 查

要显示视图的定义,需要在SHOWCREATEVIEW子句之后指定视图的名称, 我们先来创建几张表,完事后在进行演示:--用户信息表...

使用NPOI导出excel包括图片

Excl模板导出相信我们都会,那么模板上要导出图片呢?嗯~还是来个例子:准备工作:首先要引用NPOI包:然后获取数据集(我这...

ajaxSubmit异步上传图片嘘,外面都是假的

引用代码<scriptsrc="/Scripts/jquery.form.js"></script>js就在旁边img链接中,只不过大小为0x0,...

.NET MVC 使用百度编辑器详细教程:1配置编辑器

一、什么是百度编辑器百度编辑器UEditor是由百度web前端研发部开发一款应用于网站的编辑器,具有轻量,可定制,注重用户体...

使用jquery操作元素的css样式获取、修改等等

使用jquery操作元素的css样式(获取、修改等等) //1、获取和设置样式 $("#tow").attr("class")...

.net辗转java系列视野

.net辗转java系列(一)视野.net系java系其它语言C#Java框架.net Framework Standardjava se.net corejava eejave meJava S...

.NET MVC json对象或者json对象数组的序列化和反序列化

1、用JSON.stringify()将对象stuarr或者json数组stuarr序列化成字符串,然后提交给后台。$.post("/home/DoUpdate&quot...

.NET MVC json对象或者json对象数组的序列化和反序列化

1、用JSON.stringify()将对象stuarr或者json数组stuarr序列化成字符串,然后提交给后台。$.post("/home/DoUpdate&quot...

mui框架-移动端跳转以及传值的简单方法修改解决方法

纠结了两天的MUI跳转的问题,终于解决了 ,现在分享给大家,希望大家有什么坑的解决也给我分享分享 哈哈,废话不多说,上代...

MVC全局异常处理错误日记

1、在Filter文件夹中创建一个IsExceptionFilter类(类名随意取)2、使用3、在访问的页面控制器中添加几个错误4、在IsExcept...

MVC全局异常处理错误日记

1、在Filter文件夹中创建一个IsExceptionFilter类(类名随意取)2、使用3、在访问的页面控制器中添加几个错误4、在IsExcept...

Hbuilder打包APP的教程会操作的略过

首先项目必须是APP端的,可能讲解有点啰嗦,讲解准备的工具:HBuilderX(其他版本也可以,这里用X版本来讲解)、待测试手机...
这一世以无限游戏为使命!
排名
2
文章
636
粉丝
44
评论
93
docker中Sware集群与service
尘叶心繁 : 想学呀!我教你呀
一个bug让程序员走上法庭 索赔金额达400亿日元
叼着奶瓶逛酒吧 : 所以说做程序员也要懂点法律知识
.net core 塑形资源
剑轩 : 收藏收藏
映射AutoMapper
剑轩 : 好是好,这个对效率影响大不大哇,效率高不高
ASP.NET Core 服务注册生命周期
剑轩 : http://www.tnblog.net/aojiancc2/article/details/167
ICP备案 :渝ICP备18016597号-1
网站信息:2018-2025TNBLOG.NET
技术交流:群号656732739
联系我们:contact@tnblog.net
公网安备:50010702506256
欢迎加群交流技术