首页 > 其他分享 >TensorFlow tfrecord的解析

TensorFlow tfrecord的解析

时间:2023-11-17 17:44:57浏览次数:34  
标签:tfrecord default feature value length tf TensorFlow 解析 data

import tensorflow as tf
import json
# Parsing config for the TFRecord schema.  Three sections -- "label",
# "context" and "feature" -- each mapping a field name to a spec dict
# (data_type, default_value, fixed/var feature_length, shape).  Only
# entries with is_use == 1 are turned into parsing features.
aa = {
    "label": {
        "binary_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "triple_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "four_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "context": {
        "item_code": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "query": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "prod_name": {
            "is_use": 0,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "feature": {
        "query_ids": {
            "is_use": 1,
            "data_type": "int64_list",
            "default_value": 0,
            "feature_length": "var_length",
            "feature_type": "int_sequence",
            "vocab_list": ["unk", "1", "2", "3", "4", "5"],
            "shape": [-1, -1],
            "preprocess": "pad_sequence",
            "description": "搜索词token id sequence"
        },
        "title_ids": {
            "is_use": 1,
            "data_type": "int64_list",
            "default_value": 0,
            "feature_length": "var_length",
            "feature_type": "int_sequence",
            "vocab_list": ["unk", "1", "2", "3", "4", "5"],
            "shape": [-1, -1],
            "preprocess": "pad_sequence",
            "description": "商品标题token id sequence"
        },
        "query_token_ids": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 1],
            "preprocess": "pad_sequence",
            "description": "搜索词原始token id sequence"
        },
        "ic_token_ids": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 1],
            "preprocess": "pad_sequence",
            "description": "商品标题原始token id sequence"
        }
    }
}




def get_feature_description(feature_config):
    """Build the parsing spec for ``tf.io.parse_single_example`` from a config.

    The config has three sections -- "feature", "label" and "context" -- each
    mapping a field name to a spec dict with keys "is_use", "data_type",
    "default_value", "feature_length" and "shape".  Entries with
    ``is_use != 1`` are skipped.

    Args:
        feature_config: dict with "feature", "label" and "context" sections.

    Returns:
        dict mapping field name to tf.io.FixedLenFeature / tf.io.VarLenFeature.

    Raises:
        ValueError: for unsupported data_type / feature_length combinations,
            or invalid label/context specs.  (The original used ``assert``
            for validation, which is stripped under ``python -O``; explicit
            raises are used instead.)
    """
    # One FixedLenFeature factory replaces the three duplicated
    # int64/float32/string branches that appeared in every section.
    _casters = {"int64": (int, tf.int64),
                "float32": (float, tf.float32),
                "string": (str, tf.string)}

    def _fixed_len(data_type, length, default_single_value):
        # Repeat the scalar default `length` times for multi-valued features.
        if data_type not in _casters:
            raise ValueError(f"fixed_length datatype {data_type} now is not supported!")
        cast, dtype = _casters[data_type]
        default_value = ([cast(default_single_value)] * length
                         if length > 1 else cast(default_single_value))
        return tf.io.FixedLenFeature(shape=(length,), dtype=dtype,
                                     default_value=default_value)

    feature_description = {}

    # 1. "feature" section: fixed-length fields plus var-length int64 lists.
    for feature_name, stats in feature_config["feature"].items():
        if stats["is_use"] != 1:
            continue
        feature_length = stats["feature_length"]
        length = stats["shape"][1]  # shape is [-1, length]; index 0 is the batch dim
        data_type = stats["data_type"]
        if feature_length == "fixed_length":
            feature_description[feature_name] = _fixed_len(
                data_type, length, stats["default_value"])
        elif feature_length == "var_length":
            if data_type == "int64_list":
                feature_description[feature_name] = tf.io.VarLenFeature(tf.int64)
            elif data_type in ("float32_list", "string_list"):
                # TODO: not implemented yet -- the field is deliberately skipped.
                pass
            else:
                raise ValueError(f"var_length datatype {data_type} now is not supported!")
        else:
            raise ValueError(f"feature_length {feature_length} now is not supported!")

    # 2. "label" section: only scalar fixed-length int64 labels are supported.
    for label_name, stats in feature_config["label"].items():
        if stats["is_use"] != 1:
            continue
        if (stats["data_type"] != "int64"
                or stats["feature_length"] != "fixed_length"
                or stats["shape"][1] != 1):
            raise ValueError(f"label {label_name} must be a scalar fixed_length int64")
        feature_description[label_name] = tf.io.FixedLenFeature(
            shape=(1,), dtype=tf.int64,
            default_value=int(stats["default_value"]))

    # 3. "context" section: scalar fixed-length fields of any supported dtype.
    for context_name, stats in feature_config["context"].items():
        if stats["is_use"] != 1:
            continue
        if stats["shape"][1] != 1:
            raise ValueError(f"context {context_name} must be scalar (shape [-1, 1])")
        if stats["feature_length"] != "fixed_length":
            # The original silently ignored non-fixed_length context entries;
            # raise instead, consistent with the "feature" section above.
            raise ValueError(f"feature_length {stats['feature_length']} now is not supported!")
        feature_description[context_name] = _fixed_len(
            stats["data_type"], 1, stats["default_value"])

    return feature_description

# Build the parsing spec once at module level; the map function closes over it.
feature_description = get_feature_description(aa)


def parse_tfrecord_fn(example):
    """Decode one serialized tf.train.Example using the module-level spec."""
    return tf.io.parse_single_example(example, feature_description)


# Path of the TFRecord file to parse.
tfrecord_file = './data/eval.tfrecord'

# Build the dataset and decode each serialized example.
dataset = tf.data.TFRecordDataset(tfrecord_file).map(parse_tfrecord_fn)

# Dump every example.  The var-length fields are SparseTensors, hence the
# `.values` accessor; the fixed-length fields are dense tensors.
for example in dataset:
    print(example)
    for sparse_key in ('query_ids', 'title_ids'):
        print(example[sparse_key].values.numpy())
    for dense_key in ('binary_label', 'four_label', 'ic_token_ids',
                      'item_code', 'query', 'query_token_ids', 'triple_label'):
        print(example[dense_key].numpy())
    print(".............................................................")



# --- Generate a TFRecord file -------------------------------------------
import tensorflow as tf

# One hand-crafted sample matching the schema parsed above.
data = {
    'query': 'apple',
    'query_ids': [1, 2, 3],
    'title_ids': [4, 5, 6],
    'binary_label': 1,
    'four_label': 2,
    'ic_token_ids': '12345',
    'item_code': 123456,
    'query_token_ids': '67890',
    'triple_label': 3
}

# Output path for the generated TFRecord file.
train_file = './train_test.tfrecord'


def _bytes_feature(text):
    # Encode a str into a proto BytesList feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode()]))


def _int64_feature(values):
    # Wrap a list of ints into a proto Int64List feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


with tf.io.TFRecordWriter(train_file) as writer:
    feature = {
        'query': _bytes_feature(data['query']),
        'query_ids': _int64_feature(data['query_ids']),
        'title_ids': _int64_feature(data['title_ids']),
        'binary_label': _int64_feature([data['binary_label']]),
        'four_label': _int64_feature([data['four_label']]),
        'ic_token_ids': _bytes_feature(data['ic_token_ids']),
        'item_code': _int64_feature([data['item_code']]),
        'query_token_ids': _bytes_feature(data['query_token_ids']),
        'triple_label': _int64_feature([data['triple_label']]),
    }
    # Wrap the feature map in an Example proto and serialize it to disk.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example_proto.SerializeToString())

print(f'Generated TFRecord file: {train_file}')




# val_file = './data/train.tfrecord'
#
# # 创建 TFRecordWriter 对象
# with tf.io.TFRecordWriter(val_file) as writer:
# # 遍历数据并写入 TFRecord 文件
# for example in data:
# # 创建 Example 对象
# feature = {
# 'query': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['query'].encode()])),
# 'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=example['query_ids'])),
# 'title_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=example['title_ids'])),
# 'binary_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['binary_label']])),
# 'four_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['four_label']])),
# 'ic_token_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['ic_token_ids'].encode()])),
# 'item_code': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['item_code']])),
# 'query_token_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['query_token_ids'].encode()])),
# 'triple_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['triple_label']])),
# }
# # 创建 Features 对象
# features = tf.train.Features(feature=feature)
# # 创建 Example 对象
# example_proto = tf.train.Example(features=features)
# # 序列化 Example 对象并写入 TFRecord 文件
# writer.write(example_proto.SerializeToString())
#
# print(f'Generated TFRecord file: {val_file}')


# from wordcloud import STOPWORDS
#
# import re
#
# from collections import defaultdict
#
# item_dict = defaultdict(int)
#
# with open("./data/title_key.txt", "r", encoding="utf-8") as f,open("./data/query_key.txt", "r", encoding="utf-8") as f1,open("./data/data.txt", "w", encoding="utf-8") as out:
# for line in f:
# key, num = line.strip("\n").split("\t")
# if key.strip():
# item_dict[key] = int(num)
# for line in f1:
# key, num = line.strip("\n").split("\t")
# if key.strip():
# item_dict[key]+=int(num)
# sorted_dict = sorted(item_dict.items(), key=lambda x: x[1], reverse=True)
# for key,num in sorted_dict:
# if num>4 and key not in STOPWORDS:
# if re.search("^[a-z0-9]+$",key):
# out.write("{}\t{}\n".format(key, num))
#
#
#


# bert的tfrecord
import sys

import tensorflow as tf
import json
# Parsing config for the BERT-style TFRecord schema: fixed-length
# input_ids / attention_mask / token_type_ids for query (length 10) and
# title (length 30), plus the shared label/context sections.  Only
# entries with is_use == 1 are turned into parsing features.
aa = {
    "label": {
        "binary_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "triple_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "four_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "context": {
        "item_code": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "query": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "prod_name": {
            "is_use": 0,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "feature": {
        "query_input_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 10],
            "preprocess": "pad_sequence",
            "description": "原始搜索词input_ids"
        },
        "query_attention_mask": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 10],
            "preprocess": "pad_sequence",
            "description": "原始搜索词attention_mask"
        },
        "query_token_type_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 10],
            "preprocess": "pad_sequence",
            "description": "原始搜索词token_type_ids"
        },
        "title_input_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 30],
            "preprocess": "pad_sequence",
            "description": "标题input_ids"
        },
        "title_attention_mask": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 30],
            "preprocess": "pad_sequence",
            "description": "标题attention_mask"
        },
        "title_token_type_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 30],
            "preprocess": "pad_sequence",
            "description": "标题token_type_ids"
        }
    }
}





def get_feature_description(feature_config):
    """Build the parsing spec for ``tf.io.parse_single_example`` from a config.

    The config has three sections -- "feature", "label" and "context" -- each
    mapping a field name to a spec dict with keys "is_use", "data_type",
    "default_value", "feature_length" and "shape".  Entries with
    ``is_use != 1`` are skipped.

    Args:
        feature_config: dict with "feature", "label" and "context" sections.

    Returns:
        dict mapping field name to tf.io.FixedLenFeature / tf.io.VarLenFeature.

    Raises:
        ValueError: for unsupported data_type / feature_length combinations,
            or invalid label/context specs.  (The original used ``assert``
            for validation, which is stripped under ``python -O``; explicit
            raises are used instead.)
    """
    # One FixedLenFeature factory replaces the three duplicated
    # int64/float32/string branches that appeared in every section.
    _casters = {"int64": (int, tf.int64),
                "float32": (float, tf.float32),
                "string": (str, tf.string)}

    def _fixed_len(data_type, length, default_single_value):
        # Repeat the scalar default `length` times for multi-valued features.
        if data_type not in _casters:
            raise ValueError(f"fixed_length datatype {data_type} now is not supported!")
        cast, dtype = _casters[data_type]
        default_value = ([cast(default_single_value)] * length
                         if length > 1 else cast(default_single_value))
        return tf.io.FixedLenFeature(shape=(length,), dtype=dtype,
                                     default_value=default_value)

    feature_description = {}

    # 1. "feature" section: fixed-length fields plus var-length int64 lists.
    for feature_name, stats in feature_config["feature"].items():
        if stats["is_use"] != 1:
            continue
        feature_length = stats["feature_length"]
        length = stats["shape"][1]  # shape is [-1, length]; index 0 is the batch dim
        data_type = stats["data_type"]
        if feature_length == "fixed_length":
            feature_description[feature_name] = _fixed_len(
                data_type, length, stats["default_value"])
        elif feature_length == "var_length":
            if data_type == "int64_list":
                feature_description[feature_name] = tf.io.VarLenFeature(tf.int64)
            elif data_type in ("float32_list", "string_list"):
                # TODO: not implemented yet -- the field is deliberately skipped.
                pass
            else:
                raise ValueError(f"var_length datatype {data_type} now is not supported!")
        else:
            raise ValueError(f"feature_length {feature_length} now is not supported!")

    # 2. "label" section: only scalar fixed-length int64 labels are supported.
    for label_name, stats in feature_config["label"].items():
        if stats["is_use"] != 1:
            continue
        if (stats["data_type"] != "int64"
                or stats["feature_length"] != "fixed_length"
                or stats["shape"][1] != 1):
            raise ValueError(f"label {label_name} must be a scalar fixed_length int64")
        feature_description[label_name] = tf.io.FixedLenFeature(
            shape=(1,), dtype=tf.int64,
            default_value=int(stats["default_value"]))

    # 3. "context" section: scalar fixed-length fields of any supported dtype.
    for context_name, stats in feature_config["context"].items():
        if stats["is_use"] != 1:
            continue
        if stats["shape"][1] != 1:
            raise ValueError(f"context {context_name} must be scalar (shape [-1, 1])")
        if stats["feature_length"] != "fixed_length":
            # The original silently ignored non-fixed_length context entries;
            # raise instead, consistent with the "feature" section above.
            raise ValueError(f"feature_length {stats['feature_length']} now is not supported!")
        feature_description[context_name] = _fixed_len(
            stats["data_type"], 1, stats["default_value"])

    return feature_description

# Build the parsing spec once at module level; the map function closes over it.
feature_description = get_feature_description(aa)


def parse_tfrecord_fn(example):
    """Decode one serialized tf.train.Example using the module-level spec."""
    return tf.io.parse_single_example(example, feature_description)


# Path of the TFRecord file to parse.
tfrecord_file = './data/train_test.tfrecord'

# Build the dataset and decode each serialized example.
dataset = tf.data.TFRecordDataset(tfrecord_file).map(parse_tfrecord_fn)

# Print every field of the first example, then stop (all fields here are
# fixed-length, so each value is a dense tensor).
for example in dataset:
    for key in ('binary_label', 'four_label', 'item_code', 'query',
                'query_attention_mask', 'query_input_ids',
                'query_token_type_ids', 'title_attention_mask',
                'title_input_ids', 'title_token_type_ids', 'triple_label'):
        print(f"{key}:", example[key].numpy())
    sys.exit(1)




标签:tfrecord,default,feature,value,length,tf,TensorFlow,解析,data
From: https://www.cnblogs.com/qiaoqifa/p/17839364.html

相关文章

  • wcf restful 用stream接收表单数据并解析
    1.下载包HttpMultipartParser 2.服务端代码publicboolUpload(Streamstream){varparser=MultipartFormDataParser.Parse(stream);//解析streamvarfile=parser.Files.First();//获取文件stringfilename=file.Fi......
  • 实例解析html页面语言
    清晰的了解html代码表达的意思才能准确的通过代码展示出开发者的设计思路。这里总结了一些常见的的页面代码,逐行解释其表达的意思,以备能随时翻阅,常备常练。示例资料<!DOCTYPEhtml><htmllang="en"><head><metacharset="UTF-8"><metaname="viewport"content="width=......
  • CreateCollection_dataSyncService_执行流程源码解析
    CreateCollection_dataSyncService_执行流程源码解析milvus版本:v2.3.2CreateCollection这个API流程较长,也是milvus的核心API之一,涉及的内容比较复杂。这里介绍dataSyncService相关的流程。这边文章基于【CreateCollection流程_addCollectionMetaStep_milvus源码解析】这篇文章......
  • Core 6.0 webapi ‘报错InvalidOperationException:无法解析“ Microsoft.AspNetCore.H
    因接口版本升级并使用core6.0却发现HttpContext.Current.Request用不了 所以在网上找了半天说是使用Microsoft.AspNetCore.Http.IHttpContextAccessorprivateIHttpContextAccessor_httpContextAccessor;publicWebHelper(IHttpContextAccessorhttpContextAccessor......
  • 源码解析axios拦截器
    从源码解析axios拦截器是如何工作的axios拦截器的配置方式axios中有两种拦截器:axios.interceptors.request.use(onFulfilled,onRejected,options):配置请求拦截器。onFulfilled方法在发送请求前执行,接收config对象,返回一个新的config对象,可在此方法内修改config对......
  • JAVA解析Excel文件 + 多线程 + 事务回滚
    1.项目背景:客户插入Excel文件,Ececel文件中包含大量的数据行和数据列,单线程按行读取,耗时大约半小时,体验感不好。思路:先将excel文件按行读取,存入List,然后按照100均分,n=list.szie()/100+1;n就是要开启的线程总数。(实际使用的时候,数据库连接池的数量有限制,n的大小要结合数据库连......
  • JAVA 解析Excel + 多线程 + 事务回滚(2)
    该方法为网上查询,感觉可行,并未真正尝试。主线程:packagecom.swagger.demo.service;importcom.alibaba.excel.context.AnalysisContext;importcom.alibaba.excel.event.AnalysisEventListener;importcom.swagger.demo.config.SpringJobBeanFactory;importcom.swagger.demo.m......
  • 解析json
    result.SetSuccess(Util.TryGetJSONObject<JObject>("{\"obj\":{\"reply\":\""+row.response+"\"},\"code\":"+0+"}")); {"Success":true,"Message&......
  • 学习笔记426—keras中to_categorical函数解析
    keras中to_categorical函数解析1.to_categorical的功能简单来说,to_categorical就是将类别向量转换为二进制(只有0和1)的矩阵类型表示。其表现为将原有的类别向量转换为独热编码的形式。先上代码看一下效果:fromkeras.utils.np_utilsimport*#类别向量定义b=[0,1,2,3,4,5,6,7,8]......
  • bulk批量操作的json格式解析
    3.17bulk批量操作的json格式解析bulk的格式:{action:{metadata}}\n{requstbody}\n为什么不使用如下格式:[{"action":{},"data":{}}]这种方式可读性好,但是内部处理就麻烦了:1.将json数组解析为JSONArray对象,在内存中就需要有一份json文本的拷贝,另外还有一个JSONArray对象。2.解析jso......