
Pytorch Flask服务部署图片识别(学习笔记)
Flask 简介
Flask是一个用Python编写的轻量级Web应用框架。
它简单易用,但同时也足够灵活和强大,能够支持复杂的Web应用。
由于其轻量级的特性,Flask非常适合用作在Web上部署机器学习模型的工具。
简单来讲:启动一个服务,根据传上来的东西进行预测并返回结果。
安装Flask
python -m pip install flask
实践目录
文件或文件夹 | 描述 |
---|---|
flower_data |
训练的图像数据 |
best.pth |
训练好的模型 |
flask_server.py |
服务器端代码 |
flask_predict.py |
客户端请求代码 |
服务器端
服务器对需要预处理的图片流程如下图所示:
flask_server.py
代码如下所示:
import io
import json
# flask 服务
import flask
import torch
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
#from torchvision import transforms as T
from torchvision import transforms, models, datasets
from torch.autograd import Variable
# 初始化Flask app
app = flask.Flask(__name__)
model = None
use_gpu = False
# 加载模型进来
def load_model():
"""Load the pre-trained model, you can use your model just as easily.
"""
# 定义一个全局变量
global model
#这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 102类的分类任务
#print(model) 加载模型
checkpoint = torch.load('best.pth')
# 加载权重参数
model.load_state_dict(checkpoint['state_dict'])
#将模型指定为测试格式
model.eval()
#是否使用gpu
if use_gpu:
model.cuda()
# 数据预处理
def prepare_image(image, target_size):
"""Do image preprocessing before prediction on any data.
:param image: original image
:param target_size: target image size
:return:
preprocessed image
"""
#针对不同模型,image的格式不同,但需要统一至RGB格式
if image.mode != 'RGB':
image = image.convert("RGB")
# Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
# 图片与训练尺寸大小一致
image = transforms.Resize(target_size)(image)
# 转tensor格式
image = transforms.ToTensor()(image)
# Convert to Torch.Tensor and normalize. mean与std (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致
# 设置均值和标准差
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
# Add batch_size axis.增加一个维度,用于按batch测试 本次这里一次测试一张
# 举例:1*3*64*64
image = image[None]
if use_gpu:
image = image.cuda()
return Variable(image, volatile=True) #不需要求导
# 开启服务 这里的predict是API路径、使用POST请求
@app.route("/predict", methods=["POST"])
def predict():
# Initialize the data dictionary that will be returned from the view.
#做一个标志,刚开始无图像传入时为false,传入图像时为true
data = {"success": False}
# 如果收到请求
if flask.request.method == 'POST':
#判断是否为图像
if flask.request.files.get("image"):
# Read the image in PIL format
# 将收到的图像进行读取
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image)) #二进制数据
# 利用上面的预处理函数将读入的图像进行预处理
image = prepare_image(image, target_size=(64, 64))
# 放入模型中进行预测,softmax得到各个类别的概率
preds = F.softmax(model(image), dim=1)
# k找出类别前3高的
results = torch.topk(preds.cpu().data, k=3, dim=1)
# 结果转成cpu最后转成numpy
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
#将data字典增加一个key,value,其中value为list格式
data['predictions'] = list()
# 遍历每一个预测结果
for prob, label in zip(results[0][0], results[1][0]):
#label_name = idx2label[str(label)]
# label真实值,和probability概率值
r = {"label": str(label), "probability": float(prob)}
# 将预测结果添加至data字典
data['predictions'].append(r)
# Indicate that the request was a success.
data["success"] = True
# 将最终结果以json格式文件传出
return flask.jsonify(data)
"""
test_json = {
"status_code": 200,
"success": {
"message": "image uploaded",
"code": 200
},
"video":{
"video_name":opt['source'].split('/')[-1],
"video_path":opt['source'],
"description":"1",
"length": str(hour)+','+str(minute)+','+str(round(second,4)),
"model_object_completed":model_flag
}
"status_txt": "OK"
}
response = requests.post(
'http://xxx.xxx.xxx.xxx:8090/api/ObjectToKafka/',,
data={'json': str(test_json)})
"""
if __name__ == '__main__':
print("Loading PyTorch model and Flask starting server ...")
print("Please wait until server has fully started")
#先加载模型
load_model()
#再开启服务
app.run(port='5012')
这里我开放的端口是5012
,通过请求/predict
链接,通过执行如下命令将程序跑起来:
python flask_server.py
只要把Flask关了模型就没了,如果Flask一直开着的模型就一直都在跑。
客户端
客户端主要是上传一张image_06998.jpg
的图片到服务器中去预测,代码如下:
import requests
import argparse
# url和端口携程自己的
flask_url = 'http://127.0.0.1:5012/predict'
def predict_result(image_path):
#传入本地图片
image = open(image_path, 'rb').read()
payload = {'image': image}
#request发给server.
r = requests.post(flask_url, files=payload).json()
# 成功的话在返回.
if r['success']:
# 输出结果.
for (i, result) in enumerate(r['predictions']):
print('{}. {}: {:.4f}'.format(i + 1, result['label'],
result['probability']))
# 失败了就打印.
else:
print('Request failed')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification demo')
# 添加参数
parser.add_argument('--file', default='./flower_data/train_filelist/image_06998.jpg', type=str, help='test image file')
args = parser.parse_args()
# 开始请求
predict_result(args.file)
python flask_predict.py
预测结果如下所示:
我们可以看到预测得最相似的label是34
,准确率97%
,我们去图片数据中找找这张图片的训练集验证一下。
训练的结果与预期的结果一致。
欢迎加群讨论技术,1群:677373950(满了,可以加,但通过不了),2群:656732739
评价
排名
2
文章
636
粉丝
44
评论
93
docker中Sware集群与service
尘叶心繁 : 想学呀!我教你呀
一个bug让程序员走上法庭 索赔金额达400亿日元
叼着奶瓶逛酒吧 : 所以说做程序员也要懂点法律知识
.net core 塑形资源
剑轩 : 收藏收藏
映射AutoMapper
剑轩 :
好是好,这个对效率影响大不大哇,效率高不高
一个bug让程序员走上法庭 索赔金额达400亿日元
剑轩 : 有点可怕
ASP.NET Core 服务注册生命周期
剑轩 :
http://www.tnblog.net/aojiancc2/article/details/167
ICP备案 :渝ICP备18016597号-1
网站信息:2018-2025TNBLOG.NET
技术交流:群号656732739
联系我们:contact@tnblog.net
公网安备:
50010702506256


欢迎加群交流技术