首页 > 其他分享 >:)模型保存为单一个pb文件

:)模型保存为单一个pb文件

时间:2023-02-26 11:24:33浏览次数:42  
标签:img graph 模型 pb print tf 为单 model

模型保存为单一个pb文件

背景

参考连接: https://www.yuque.com/g/jesse-ztr2k/nkke46/ss4rlv/collaborator/join?token=XUVZNORisVWEWyst# 

注意有些时候需要添加一个pb文件。 而不是tensorflow 提供的save 方法生成的一个目录里面包含了若干pb文件。

load时候直接填写这个目录即可。 但是有些时候需要合成一个pb文件。

 

tf2生成pb 目录描述

  1 目录结构

 -assets

-variables

-variables.data-00000-of-00001

-variables.index

-saved_model.pb

  2 作用

    其中 variables 记录模型参数 , pb文件记录模型结构

tf2 都是保存的 权重和 结构分开的, 如果需要兼容tf V1的代码,即导入一个pb文件,就需要 1 )保存常量计算图 2)frozen graph  pb格式。

tf1 生成pb脚本

环境准备:

tensorflow==1.15, tf-slim==1.1.0

https://github.com/tensorflow/models/tree/master/research/slim

注意 一定在tf v1 环境下生成pb

  1 import cv2
  2 import numpy as np
  3 import tensorflow as tf
  4 import os
  5 from tensorflow.python.framework import graph_util
  6 
  7 # 参考连接 https://blog.csdn.net/tensorflowforum/article/details/112352764 代码
  8 # 参考连接 参数详解:https://blog.csdn.net/weixin_43529465/article/details/124721583
  9 # https://blog.csdn.net/rain6789/article/details/78754516
 10 
 11 class SingleCnn(tf.keras.Model):
 12     def __init__(self):
 13         super(SingleCnn, self).__init__()
 14         # filters=1 卷积核数目,相当于卷积核的channel
 15         self.conv = tf.keras.layers.Conv2D(filters=1,
 16                                            kernel_size=[1, 1],
 17                                            # valid表示不填充, same表示合理填充
 18                                            padding='valid',
 19                                         # data_format='channels_last',-> 表示HWC,输入可以定义批次
 20                                            data_format='channels_last',
 21                                            use_bias=False,
 22                                            kernel_initializer=tf.keras.initializers.he_uniform(seed=None),
 23                                            name="conv")
 24 
 25     def call(self, inputs):
 26         x = self.conv(inputs)
 27         return x
 28 if __name__ == "__main__":
 29     # 构建场景输入数据
 30 
 31     # images=tf.random.uniform((1, 300, 300, 3))
 32 
 33     # 图像数据
 34     imagefile = r"catanddog\cat\5.JPG"
 35     img = cv2.imread(imagefile)
 36     img = cv2.resize(img, (64, 64))
 37     img = np.expand_dims(img, axis=0)
 38     print(img.shape, type(img), img.dtype)
 39 
 40     # 未量化的model不支持int32和int8
 41     # img = img.astype(np.int32)
 42     img = tf.convert_to_tensor(img, np.float32)
 43     print(img.shape, type(img), img.dtype)
 44     singlecnn = SingleCnn()
 45 
 46     output = singlecnn(img)
 47     print(output.shape, type(output))
 48     print(output[0][2:10][2:6])
 49     # =========== ckpt保存 with session的写法tf2 已不再使用 ===========
 50     # with tf.Session(graph=tf.Graph()) as sess:
 51     #     constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
 52 
 53     # 保存参考 https://zhuanlan.zhihu.com/p/146243327
 54     # save_format='tf' 代表保存pb
 55     # singlecnn.save('./pbmodel/singlecnn', save_format='tf')
 56     # tf.saved_model.save(singlecnn, './pbmodel/singlecnn')
 57     tf.keras.models.save_model(singlecnn, './pbmodel/singlecnn_0',
 58                                save_format="tf",
 59                                include_optimizer=False, save_traces=False)
 60 
 61     # 加载模型 验证可以加载
 62     new_model = tf.keras.models.load_model('./pbmodel/singlecnn_0', compile=False)
 63     # new_model = tf.saved_model.load('./pbmodel/singlecnn_0')
 64     # output_ = new_model(img)
 65     # # print(output_.shape, output_[0][2:6][2:6])
 66     # print(output_.shape)
 67     #
 68     # 查看结构
 69     new_model.summary()
 70 
 71     # print("----------------")
 72     # # 加载模型
 73     # saved_model = tf.saved_model.load('./pbmodel/singlecnn_0')
 74     # # 将模型转换为pb格式    还是目录方法。
 75     # converter = tf.saved_model.save(saved_model, "model.pb")
 76 
 77     def change_pb(pretrained_model):
 78         """tf v1 选用tf1 跑这个脚本生成pb"""
 79         from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
 80         # 重点
 81         # Convert Keras model to ConcreteFunction
 82         # MobileNet is a function
 83         full_model = tf.function(lambda x: pretrained_model(x))
 84 
 85         # 指定shape和dtype对tf function进行重新追踪
 86         full_model = full_model.get_concrete_function(
 87             tf.TensorSpec(pretrained_model.inputs[0].shape, pretrained_model.inputs[0].dtype))
 88 
 89         # Get frozen ConcreteFunction,将计算图中的变量及其取值通过常量的方式保存
 90         frozen_func = convert_variables_to_constants_v2(full_model)
 91         frozen_func.graph.as_graph_def()
 92 
 93         layers = [op.name for op in frozen_func.graph.get_operations()]
 94         print("-" * 50)
 95         print("Frozen model layers: ")
 96         for layer in layers:
 97             print(layer)
 98 
 99         print("-" * 50)
100         print("Frozen model inputs: ")
101         print(frozen_func.inputs)
102         print("Frozen model outputs: ")
103         print(frozen_func.outputs)
104 
105         # Save frozen graph from frozen ConcreteFunction to hard drive
106         # as_text: If True, writes the graph as an ASCII proto; otherwise, The graph is written as a text proto
107         tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
108                           logdir="./frozen_models",
109                           name="frozen_graph.pb",
110                           as_text=True)
111 
112 
113 change_pb(new_model)
model_getpb

python download_and_convert_data.py --dataset_name=flowers --dataset_dir="tmp/dataset"

 

 

标签:img,graph,模型,pb,print,tf,为单,model
From: https://www.cnblogs.com/lx63blog/p/17156312.html

相关文章

  • c: machine0 - 机器语言的模型机
    c: machine0-机器语言的模型机    一、源码1[wit@eaglesrc]$catmachine0.c2#include<stdio.h>3#include<stdlib.h>4#include<string.h>5......
  • ES 数据副本模型
    除了可以承受更多的并发流量、存储海量数据外,分布式系统另外一个优点就是:利用数据备份来防止数据丢失。但也正是由于数据副本的存在,也引入了一些其他的问题,比如,如何选取主......
  • 22、模型的保存与读取
    1、模型的保存1'''1、模型的保存'''2importtorch3importtorchvision45vgg16=torchvision.models.vgg16(pretrained=False)6#保存方式1:保存网络模型......
  • 899~900 Maven 指令的生命周期,概念模型图
    Maven指令的生命周期maven对项目构建过程分为三套相互独立的生命周期,请注意这里说的是“三套”,而且“相互独立”,这三套生命周期分别是:CleanLifecy......
  • 预训练语言模型基础知识串讲
    预训练语言模型基础知识串讲_Bolin-BGI的CSDN博客 ......
  • 在Google的TPU上训练Fashion MNIST图像识别模型
    作者|张强今天我们要训练的模型是基于Keras框架,来训练FashionMNIST图像识别模型,该模型和MNIST是一样的分类数量。​​MNIST​​​的分类是0到9的十个数字​​​FashionMN......
  • Unity 进去区域时显示模型,离开区域时候隐藏模型
    1.在“MainCamera”下面创建一个Cube,调整大小,并将其Tag指定为“Player”(如下图所示)。 2.新建一个脚本,命名为“Tourcamera”,用来控制相机的移动(代码如下),并将其挂......
  • 浪潮以AI算力服务助力,网易大模型问鼎中文语言评测分类冠军
    日前,网易伏羲中文预训练大模型“玉言”登顶中文语言理解权威测评基准CLUE分类任务榜单,在多项任务上超过人类水平。其具备的自然语言处理能力,可应用于语言助手文本创作、新闻......
  • C#根据模型生成图片
    usingSystem.Collections;usingSystem.Collections.Generic;usingUnityEngine;usingUnityEditor;publicclassExportPng:MonoBehaviour{publicGameObj......
  • T-SQL——将字符串转为单列
    目录0.背景1.使用STRING_SPLIT函数2.自定义分裂函数3.使用示例shanzm-2023年2月22日0.背景代码中执行存储过程,参数是多个且不确定数量,期望SQL查询时使用该参数作......