首页 > 编程语言 >Tensorflow Serving部署及客户端访问编程实践

Tensorflow Serving部署及客户端访问编程实践

时间:2024-08-06 23:25:25浏览次数:19  
标签:Serving image label Tensorflow TensorFlow model path 客户端

昨天我们实现了Tensorflow.js的花卉识别程序,它的优点是不需要服务器支持,在客户端就可以完成花卉识别,使用非常方便,但也存在一些缺点。对于很多深度学习的应用来说,由于其训练模型复杂、计算量大,所以,一般来说,仍然需要服务器支持。下面仍然以花卉识别为例,介绍如何部署Tensorflow Serving及客户端编程。

TensorFlow Serving 是由 Google 开发和维护的开源项目,是 TensorFlow 生态系统的一部分,专门用于高效地部署和服务机器学习模型,具有高性能、灵活性、易于集成、可扩展性、易于管理和健壮性等多方面的优点。最重要的是它与 TensorFlow 紧密集成,实现了与 TensorFlow 生态系统无缝集成,支持 TensorFlow 模型的完整生命周期管理,从训练到部署再到监控。并且能够直接加载和使用 TensorFlow 的 SavedModel 格式,无需额外的转换步骤。
相对于许多通用的 web 服务器和 API 服务器(如 Flask、Django、FastAPI 等),但 TensorFlow Serving 专门针对机器学习模型的服务进行了优化,包括高效的内存管理、请求批处理、多线程处理等,能够在高并发和高负载的场景下表现出色。

这里不介绍Tensorflow Serving的安装,只介绍与编程有关的部署等问题。

文末附完整源代码链接。

一、服务端部署训练模型

1. 配置模型

按如下目录存放训练SavedModel模型:

/path/to/your/model/
└── your_model/
    └── 1/
        ├── saved_model.pb
        └── variables/
            ├── variables.data-00000-of-00001
            └── variables.index

2. 启动 TensorFlow Serving

执行以下命令:

tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=your_model --model_base_path=/path/to/your/model/your_model

这条命令用于启动 TensorFlow Serving 服务器,加载指定的模型,并配置其服务端口和 API 端口。以下是每个参数的详细解释:

3. 命令和参数解释

tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=your_model --model_base_path=/path/to/your/model/your_model
  • tensorflow_model_server: 这是启动 TensorFlow Serving 服务器的命令。
  • --port=8500: 指定 gRPC API 的端口号。gRPC 是一种高性能的远程过程调用(RPC)框架,适用于需要高吞吐量和低延迟的应用场景。
  • --rest_api_port=8501: 指定 RESTful API 的端口号。RESTful API 基于 HTTP 协议,使用起来简单且广泛应用,方便客户端通过 HTTP 请求与 TensorFlow Serving 进行交互。
  • --model_name=your_model: 指定模型的名称。在服务中使用这个名称来引用和请求这个模型。这个名称可以在客户端请求中用来标识和调用特定的模型。
  • --model_base_path=/path/to/your/model/your_model: 指定模型所在的目录路径。TensorFlow Serving 会在这个目录中查找并加载模型。该路径应包含模型的文件和子目录。

二、客户端程序

1. 使用gRPC协议访问服务器

下面的代码实现了gRPC客户端 与 TensorFlow Serving 服务器交互。客户端对图片进行预处理后,向服务器发送请求,服务器完成花卉识别后,向客户端返回结果。以下是对关键代码的解释:

(1)图像预处理函数

def process_image(image: np.ndarray) -> np.ndarray:
    image_tensor = tf.convert_to_tensor(image)
    image_resized = tf.image.resize(image_tensor, (224, 224))
    image_resized /= 255

    return image_resized.numpy()
  • 将输入的图像数组转换为 TensorFlow 张量。
  • 调整图像大小为 (224, 224)
  • 将图像归一化到 [0, 1] 范围。
  • 返回预处理后的图像数组。

(2)加载和预处理图像的函数

def load_image(image_path):
    im = Image.open(image_path)
    image_arr = np.asarray(im)
    processed_image = process_image(image_arr)
    processed_image = np.expand_dims(processed_image, 0)
        
    return processed_image
  • 加载图像文件并转换为 NumPy 数组。
  • 调用 process_image 进行预处理。
  • 将图像扩展为 (1, 224, 224, 3) 形状,适应批处理输入。
  • 返回预处理后的图像。

(3)加载标签映射的函数

def load_label_map(label_map_path):
    with open(label_map_path, 'r', encoding='utf-8') as f:
        label_map = json.load(f)
    return label_map
  • 从 JSON 文件中加载标签映射。
  • 返回标签映射的字典。

(4)创建 gRPC 频道和存根

channel = grpc.insecure_channel('your server:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
  • 创建一个 gRPC 频道,连接到 TensorFlow Serving 服务。
  • 创建一个存根,用于与 TensorFlow Serving 进行通信。

(5)创建预测请求

request = predict_pb2.PredictRequest()
request.model_spec.name = 'ai_flower'
request.model_spec.signature_name = 'serving_default'
  • 创建一个 PredictRequest 对象。
  • 设置模型名称 ai_flower 和签名名称 serving_default

(6)读取和预处理图像

image_path = 'test_images/image_00250.jpg'
input_image = load_image(image_path)
  • 设置图像路径。
  • 调用 load_image 函数读取和预处理图像。

(7)设置请求输入张量

request.inputs['keras_layer_input'].CopyFrom(
    tf.make_tensor_proto(input_image, shape=input_image.shape))
  • 将预处理后的图像设置为请求的输入张量。

(8)发送请求并获取响应

response = stub.Predict(request)
  • 发送预测请求并获取响应。

(9)提取预测结果

output_tensor_name = 'dense'  # 修改为实际的键名
if output_tensor_name in response.outputs:
    predictions = tf.make_ndarray(response.outputs[output_tensor_name])
else:
    print(f"Output tensor '{output_tensor_name}' not found in the response.")
    predictions = []
  • 假设输出张量的键名是 dense,从响应中提取预测结果。
  • 如果键名不同,请根据实际情况进行修改。

2. 使用REST API协议访问服务器

与上述使用gRPC协议访问服务器实现的功能一样。以下只对有区别代码的进行解释:

(1)服务器 URL

server_url = 'http://your_server:8501/v1/models/ai_flower:predict'
  • 指定 TensorFlow Serving 服务器的 URL,发送预测请求到 ai_flower 模型的 predict 端点。

(2)发送 POST 请求到服务器

response = requests.post(server_url, json=data)
  • 通过 POST 请求将图像数据发送到 TensorFlow Serving 服务器。

(3)检查响应状态

if response.status_code == 200:
    result = response.json()
    predictions = np.array(result['predictions'])
    label_map = load_label_map('label_map.json')

    top_k = 5
    top_indices = np.argsort(predictions[0])[-top_k:][::-1]
    for i in top_indices:
        label_id = i + 1
        label_name = label_map.get(str(label_id), 'Unknown')
        confidence = predictions[0][i]
        print(f"label_id: {label_id}, Label: {label_name}, Confidence: {confidence:.4f}")
else:
    print(f"Request failed with status code {response.status_code}")
    print("Response:", response.text)
  • 检查响应状态码是否为 200(即请求成功)。
  • 解析响应 JSON 数据,提取预测结果。
  • 加载标签映射文件。
  • 获取前 5 位预测结果,打印每个预测类别的标签和信心分数。
  • 如果请求失败,打印状态码和响应内容。

完整源代码

标签:Serving,image,label,Tensorflow,TensorFlow,model,path,客户端
From: https://blog.csdn.net/playlaugh/article/details/140940200

相关文章

  • 部署CPU与GPU通用的tensorflow:Anaconda环境
      本文介绍在Anaconda环境中,下载并配置Python中机器学习、深度学习常用的新版tensorflow库的方法。  在之前的两篇文章PythonTensorFlow深度学习回归代码:DNNRegressor与PythonTensorFlow深度神经网络回归:keras.Sequential中,我们介绍了利用Python中的tensorflow库,实现机器学......
  • RabbitMQ(三)Java客户端
    1.快速入门在idea里面创建两个springboot项目,一个模块是consumer,一个是publisher两者有自己的启动类,继承同一父工程的pom。父工程的pom.xml<?xmlversion="1.0"encoding="UTF-8"?><projectxmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http......
  • SqlDbx客户端连接服务器Oracle数据库
    查了很多文章,介绍的不对,走了好多弯路,最后整理一下,供参考一、下载Oracle客户端1、SqlDbx如果是32位的,客户端也要下载32位的2、Oracle客户端版本要和服务端版本一致(本例用的是12.1.0.2.0)3、32位客户端下载地址:https://www.oracle.com/database/technologies/instant-client/mic......
  • [Microsoft][ODBC Driver 17 for SQL Server]TCP 提供程序:错误代码 0x2746 - 客户端无
    我正在尝试运行odoo项目ubuntu:22.04pyodbc==4.0.28python:3.8opensslversion:OpenSSL1.1.1p21Jun2022**ERROR:DIAG[08001][Microsoft][ODBCDriver17forSQLServer]TCPProvider:Errorcode0x2746DIAG[08001][Microsoft][ODBCDriver......
  • 安装CPU版本的TensorFlow教程
    在这篇博客中,我将详细介绍如何在Conda虚拟环境中安装CPU版本的TensorFlow。第一步,首先在安装前你需要检查你的电脑中是否安装VisualStudio,在应用中搜索visual,如下图,如果有就跳到下一步,没有的话就跟着下面步骤安装。 VisualStudio安装,复制链接到浏览器打开,选择适合你电脑......
  • 基于Java swing + MySQL电影院订票与管理系统,分为客户端和服务端
    一、需求分析电影院购票与管理系统......
  • Spring HTTP 客户端
    前言Spring提供了一些HTTP客户端类,可以方便地发起HTTP请求。如果需要了解更多SpringWeb的相关内容,可参考SpringWeb指南。RestTemplateRestTemplate是SpringWeb模块提供的一个同步的HTTP客户端,在Spring5(SpringBoot2)中可使用。它提供了一系列的HTTP请......
  • 书籍分享《TensorFlow机器学习实战指南》从入门到实战,免费领取!
    Google公司开发的TensorFlow深度学习库因其简单易学、应用场景广泛已经快成为各家公司开展人工智能研究的标配了。TensorFlow机器学习实战指南作者:NickMcClure,资深数据科学家,就职于美国西雅图PayScale公司,曾经在Zillow公司和Caesar’sEntertainment公司工作,获得蒙......
  • 【眼疾病识别】图像识别+深度学习技术+人工智能+卷积神经网络算法+计算机课设+Python+
    一、项目介绍眼疾识别系统,使用Python作为主要编程语言进行开发,基于深度学习等技术使用TensorFlow搭建ResNet50卷积神经网络算法,通过对眼疾图片4种数据集进行训练('白内障','糖尿病性视网膜病变','青光眼','正常'),最终得到一个识别精确度较高的模型。然后使用Django框架开发Web网......
  • SSH客户端客户端工具都有哪些?
    以下是一些常见的SSH客户端:根据客户端划分常用的PuTTY特点:PuTTY是一款开源的SSH和Telnet客户端,以其轻量级和便捷性而广受欢迎。它支持多种协议,包括SSH、Telnet、rlogin和原始TCP连接。支持平台:Windows、Mac和Linux。功能:提供简洁直观的用户界面,支持SSH密钥认证,增强连接......