想要部署深度学习模型?试试 FLASK 构建 REST API 部署
首发于 towardsdeeplearning 想必大家都训练出过比较好玩的模型,但是是不是想要向别人提供下接口或者自己试着玩下,这时候就需要涉及到部署模型了,这里,我们将使用 Flask 部署 PyTorch 模型,并构建用于模型推理的REST API。要注意的是: 使用 Flask 是为 PyTorch 模型提供服务的最简单方法,但不适用于具有高性能要求的场景。对高性能有要求的场景,可以使用 TorchScript,下次再说。环境安装:pip install Flask==1.0.3 torch==1.2.0 torchvision-0.3.0假设我们的场景是上传图片进行返回图片的分类结果,那么我们定义下 API 形式,请求和响应类型。将 API endpoint 将位于 /predict,接受带有包含图像的文件参数的 HTTP POST 请求。响应将是包含预测结果的 JSON 响应:{"class_id": "xx", "class_name": "yy"}首先先复习下,构建一个简单的 Web 服务器
from flask import Flask app = Flask(__name__) @app.route(/) def hello(): return welcome to !运行FLASK_ENV=development FLASK_APP=app.py flask run访问 :5000/ 可以看到 welcome tohttp://towardsdeeplearning.com !可以查看 flask 文档,熟悉下 post。为了符合上边 api 的定义,我们需要修改下代码:
from flask import Flask, jsonify app = Flask(__name__) @app.route(/predict, methods=[POST]) def predict(): return jsonify({class_id: IMAGE_NET_XXX, class_name: Cat})到此,骨干网络已经搭建完毕。还缺少什么呢?上边这个是返回的json是写死的,但是实际上要根据 post 的图片进行预测。图片通过 HTTP POST 请求传递过来, 可以通过下面这个方式获取
@app.route(/predict, methods=[POST]) def predict(): if request.method == POST: # we will get the file from the request file = request.files[file]搭建下预测的代码,这里使用了 mnasnet ,可以在 torchvision 导入预训模型。 mnasnet 的输入图片是 3 通道的 RGB 模型,大小为 224 x 224。其实熟悉 pytorch 的同学应该很容易写出前向预测的代码的。
import io import torchvision.transforms as transforms from PIL import Image def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) # 接收的图片是 bytes 转成图片格式,再进行转换 image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0) from torchvision import models model = models.mnasnet1_0(pretrained=True) model.eval() def predict(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, pred = outputs.max(1) return predpredict 的结果是类别的id,为了方便显示,我们需要进行转成文字, 就是具体的类别,狗狗啊这样人类可读性好的。
import json imagenet_class_index = json.load(open(imagenet_class_index.json)) def predict(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx]最后,整理的代码如下
import io import json import torchvision.transforms as transforms from PIL import Image from flask import Flask, jsonify, request from torchvision import models app = Flask(__name__) imagenet_class_index = json.load(open(./imagenet_class_index.json, "r")) model = models.mnasnet1_0(pretrained=True) model.eval() def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0) def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx] @app.route(/predict, methods=[POST]) def predict(): if request.method == POST: file = request.files[file] img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({class_id: class_id, class_name: class_name}) if __name__ == __main__: app.run()使用下面的命令运行。FLASK_ENV=development FLASK_APP=app.py flask run使用下面的测试代码,进行测试。
import requests resp = requests.post(":5000/predict", files={"file": open(dog.jpg,rb)}) print( resp.json() # {"class_id": "xx", "class_name": "xx"}