首页 > 其他分享 >以图搜图

以图搜图

时间:2022-11-15 10:34:14浏览次数:44  
标签:name 搜图 image time collection 以图 图像 import

以图搜图的基本原理:

  以图搜图是一种基于内容的图像检索 (CBIR) 技术²,它的特点是无需关键字就能理解图像的相关内容,主要依赖于 AI 算法,目前一些排名较好的图像分类算法可以到达 99% 准确率(TOP5)³。本文将利用 AI 模型提取图像特征向量,通过特征向量计算来完成以图搜图。

一 ,Towhee & Milvus

  Towhee (http://github.com/towhee-io/towhee)提供开箱即用的 Embedding 流水线可以将任何非结构化数据(图像,视频,音频等)转为特征向量,通过 Towhee 我们运行一条流水线就能轻松得到特征向量。

  Milvus(http://github.com/milvus-io/milvus) 是一个开源的向量数据库项目,它支持丰富的向量索引算法和向量计算方式,轻松实现对数百万、数十亿甚至数万亿向量的相似性搜索,具有高度灵活、稳定可靠以及高速查询等特点。

  通过 Towhee + Milvus 就可以实现端到端的图像等非结构化数据分析。我们先使用 Towhee 完成非结构化数据的特征向量提取,然后 Milvus 负责存储并搜索向量,最终获取与查询数据最相似的结果并展示。

Towhee 和 Milvus 的安装:

  注意:Milvus 支持单机安装和集群安装,本文使用docker-compose(http://milvus.io/docs/v2.0.x/install_standalone-docker.md)方式安装单机 Milvus,在此之前请先检查本机环境的软硬件条件(http://milvus.io/docs/v2.0.x/prerequisite-docker.md)。

#安装 Towhee

$ pip install towhee

#安装单机版 Milvus
$ wget http://github.com/milvus-io/milvus/releases/download/v2.0.2/milvus-standalone-docker-compose.yml -O docker-compose.yml
$ docker-compose up -d

 

  Towhee 支持图像 Embedding,音频 Embedding,视频 Embedding 等非结构化数据特征提取的方法,这些都被称为 Towhee 的算子(Operator),算子是流水线(Pipeline)中的单个节点,一个图像特征提取流水线就可以通过连接 image_decode(http://towhee.io/image-decode/cv2) 算子和 image_embedding.timm(http://towhee.io/image-embedding/timm) 算子实现,其中 Embedding 算子可以通过指定model_name="resnet50"利用 ResNet50 模型生成特征向量

代码:

import towhee
towhee.glob['path']('./test/lion/n02129165_13728.JPEG') \
      .image_decode['path', 'img']() \
      .image_embedding.timm['img', 'vec'](model_name='resnet50') \
      .select['img', 'vec']() \
      .show()

  接下来在 Milvus 数据库中创建集合(Collection),集合中的 Fields 包含两列:id 和 embedding,其中 id 是集合的主键。另外我们可以为 embedding 创建 IVF_FLAT (http://milvus.io/docs/v2.0.x/index.md#IVF_FLAT) 基于量化的索引,其中索引的参数是 nlist=2048,计算方式是 "L2" 欧式距离:

代码:

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
def create_milvus_collection(collection_name, dim):
    connections.connect(host='127.0.0.1', port='19530')
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    fields = [
    FieldSchema(name='id', dtype=DataType.INT64, descrition='ids', is_primary=True, auto_id=False),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)
    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2',
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
collection = create_milvus_collection('reverse_image_search', 2048)

图像数据入库

  Towhee 不光拥有丰富的算子来处理非结构化数据,还提供了简单好用的接口来处理各种数据,当然也集成了 Milvus 的一些基本用法,通过在“流水线”中连接这些算子或接口,图像入库操作将变得十分Milvus简单。

 

import towhee 
dc = ( 
    towhee.read_csv('reverse_image_search.csv') #读取 CSV 格式的表格,包含了 id,path 和 label 列
 .runas_op['id', 'id'](func=lambda x: int(x)) #将每一行的 id 从 str 类型转为 int 类型
 .image_decode['path', 'img']() #读取每一行 path 对应的图像,并将其解码为 Towhee 的图像格式 
 .image_embedding.timm['img', 'vec'](model_name='resnet50') #提取特征向量
 .tensor_normalize['vec', 'vec']() #将向量进行归一化
 .to_milvus['id', 'vec'](collection=collection, batch=100) #将 id 和 vec 批量 100 条插入到 Milvus 集合
)

查询图像并展示

  查询图像时需要的图像处理算子与前面类似,包括image_decodeimage_embedding.timmtensor_normalize,而在最后分析检索结果时,需用到数据准备部分定义好的read_images函数,通过指定runas_op中的func将该函数加入到 Towhee 流水线中。

 

(towhee.glob['path']('./test/w*/*.JPEG') #读取满足指定模式下的所有图片数据为 path 
 .image_decode['path', 'img']() #读取每一行 path 对应的图像,并将其解码为 Towhee 的图像格式  
 .image_embedding.timm['img', 'vec'](model_name='resnet50') #提取特征向量
 .tensor_normalize['vec', 'vec']() #将向量进行归一化
 .milvus_search['vec', 'result'](collection=collection, limit=5) #在 Milvus 集合中搜索向量,并返回结果
 .runas_op['result', 'result_img'](func=read_images) #处理 Milvus 的检索结果,最终返回图像用于展示
 .select['img', 'result_img']() #选择指定列; 
 .show()
)

二,

1,选用resnet网络提取图像特征

2,milvus建表,用milvus存放图像特征,通过唯一ID(此处称:milvus_id)与图像一一对应,sql建表将milvus_id作为唯一索引,存放图像的其他信息

3,异步添加图像,同步搜索图像,添加图像的量通常会很大,因此采用异步批量的方式将图像特征加载到milvus,图像添加服务会将每次的请求信息存到sql,写个脚本专门用来定时批量加载图像特征到milvus,由于是异步操作,可能会出现重复加载的情况,此处使用redis进行去重。图像搜索的请求通常会比图像添加少很多,因此图像搜索使采用同步方式返回结果;

(总结:需建立三个表:milvus表1,存放图像特征;sql表2,存放图像信息,数据与milvus表1一一对应;sql表3,存放图像添加请求信息,用于图像特征异步批量加载到milvus)

图像向量化

"""

功能:图像向量化 """ from keras.applications.resnet50 import ResNet50 from keras.preprocessing import image from keras.applications.resnet50 import preprocess_input, decode_predictions import numpy as np from numpy import linalg as LA import time   model = ResNet50(weights='imagenet') # model.summary()     def img2feature(img_path, input_dim=224):  # 图像路径???图像数据     img = image.load_img(img_path, target_size=(input_dim, input_dim))     = image.img_to_array(img)     = np.expand_dims(x, axis=0)     = preprocess_input(x)     = model.predict(x)     = / LA.norm(x)     return x     def main():     img_path = '1.jpg'     t0 = time.time()     res = img2feature(img_path)     print(time.time() - t0, res.shape)     # print(res, type(res), res.shape)     if __name__ == "__main__":     main()  

milvus表的操作

# coding:utf-8 from functools import reduce import numpy as np import time from img2feature import img2feature from pymilvus import (     connections, list_collections,     FieldSchema, CollectionSchema, DataType,     Collection, utility )     field_name = 'image_feature' host = '***.***.***.***' port = '19530' dim = 1000 default_fields = [     FieldSchema(name="milvus_id", dtype=DataType.INT64, is_primary=True),     FieldSchema(name="feature", dtype=DataType.FLOAT_VECTOR, dim=dim),     FieldSchema(name="create_time", dtype=DataType.INT64) ]     # create_table def create_table():     connections.connect(host=host, port=port)     # create collection       default_schema = CollectionSchema(fields=default_fields, description="test collection")       print(f"\nCreate collection...")     collection = Collection(name=field_name, schema=default_schema)     print(f"\nCreate index...")     default_index = {"index_type""FLAT""params": {"nlist"128}, "metric_type""L2"}     collection.create_index(field_name="feature", index_params=default_index)     print(print(f"\nCreate index...is OKOKOKOKOK"))     collection.load()     # insert data def insert_data():     connections.connect(host=host, port=port)     default_schema = CollectionSchema(fields=default_fields, description="test collection")     collection = Collection(name=field_name, schema=default_schema)     vectors = img2feature('1.jpg').tolist()[0]     print(type(vectors), len(vectors))     data1 = [         [123],         [vectors],         [int(time.time())]     ]     collection.insert(data1)     print('insert compete')     # search data def search_data():     print('search')     connections.connect(host=host, port=port)     collection = Collection(name=field_name)     print('连接成功')       # 首次查询建立索引和load()     # default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}     # print(f"\nCreate index...")     # collection.create_index(field_name="feature", index_params=default_index)     # print(print(f"\nCreate index...is OKOKOKOKOK"))     # collection.load()     # exit()       vectors = img2feature('1.jpg').tolist()[0]       topK = 10     search_params = {"metric_type""L2""params": {"nprobe"10}}       res = collection.search(         [vectors],         "feature",         search_params,         topK,         "create_time > {}".format(0),         output_fields=["milvus_id"]     )     print('>>>', res)     for hits in res:         print(len(hits))         for hit in hits:             print(hit)     print('查询结束')     def show_nums():     connections.connect(host=host, port=port)     collection = Collection(name=field_name)     print('ok')     print(collection.num_entities)     # delete data def delete_table():     connections.connect(host=host, port=port)     default_schema = CollectionSchema(fields=default_fields, description="test collection")     collection = Collection(name=field_name, schema=default_schema)     print('>>>', utility.has_collection(field_name))     collection.drop()     print('>>>', utility.has_collection(field_name))     if __name__ == "__main__":     t1 = time.time()     # create_table()     # insert_data()     # search_data()     show_nums()     # delete_table()     print('time cost: {}'.format(time.time() - t1))  

图像添加、搜索服务

from rest_framework.views import APIView as View from kpdjango.response import SucessAPIResponse, ErrorAPIResponse from kpmysql.base import Kpmysql from core import search_image import kplog import logging log = logging.getLogger("console")     class add_image(View):     def post(self, requests):         try:             db = Kpmysql.connect("db168")             cur = db.cursor()             image_info = requests.POST.get('image_info')             image_path = requests.POST.get('image_path')             sql = "INSERT INTO t_image_search_image_add_log(image_path, info) VALUES(%s, %s)"             cur.execute(sql, (image_path, image_info))             db.commit()             log.info('添加图像成功:{}-{}'.format(image_path, image_info))             return SucessAPIResponse(msg="Success")         except Exception as e:             log.info('添加图像失败:{}'.format(e))             return ErrorAPIResponse(msg="Fail")     class search_image(View):     def post(self, requests):         try:             image_path = requests.POST.get('image_path')             res = search_image(image_path)             log.info('查询图像成功:{}-{}'.format(image_path, res))             return SucessAPIResponse(msg="Success", data={"data": res})         except Exception as e:             log.info('查询图像成功:{}'.format(e))             return ErrorAPIResponse(msg="Fail")

图像异步批量加载

import time, datetime from kpmysql.base import Kpmysql from core import insert_data_many from concurrent.futures import ThreadPoolExecutor import redis from conf.setting import REDIS from core import str2time import kplog import logging   log = logging.getLogger("console") log_addimgs = logging.getLogger("console_addimgs")     def worker(datas):     try:         redis_cli = redis.Redis(host=REDIS.get('host'), port=REDIS.get('port'), password=REDIS.get('password'),                                 db=REDIS.get('db'))         dics = []         ids = []         for data in datas:             if redis_cli.zscore('image_search'str(data[0])):  # 基于redis去重                 continue             dics.append({'image_path': data[1], 'create_time': data[2]})             ids.append((data[0]))             redis_cli.zadd('image_search', {str(data[0]): str2time(data[2])})         # 数据插入milvus         insert_data_many(dics)         # 更新 set t_image_search_image_add_log is_load=1         sql_update = """UPDATE t_image_search_image_add_log SET is_load=1 WHERE id=%s"""         db168 = Kpmysql.connect("db168")         cur168 = db168.cursor()         cur168.executemany(sql_update, ids)         db168.commit()     except Exception as e:         print(e)     def main():     max_workers = 20  # 最大线程数     pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='Thread')     task_list = []     init_time = datetime.datetime.now() - datetime.timedelta(hours=13)     create_time_init = '2020-2-22 00:00:00'     while True:         now = datetime.datetime.now()         diff = now - init_time         if diff.seconds > 3600:             # 加载 t_image_search_image_add_log where is_load=0 数据             db168 = Kpmysql.connect("db168")             cur168 = db168.cursor()             sql = """SELECT id, image_path, create_time FROM t_image_search_image_add_log WHERE is_load=0 and create_time >= %s ORDER BY create_time"""             cur168.execute(sql, create_time_init)             datas = cur168.fetchall()             create_time_init = datas[-1][2]                 while True:                 for _i, _n in enumerate(task_list):                     if _n.done():                         task_list.pop(_i)                 if len(task_list) < int(max_workers * 0.9):                     break             task_list.append(pool.submit(worker, datas))             init_time = now         time.sleep(600)     if __name__ == "__main__":     main()  

优化

1. keras在调用GPU时并开启多线程时不如pytorch方便,pytorch占用显存更少;

2. 定时从数据库拿数据,改成kafka生产消费模型,代码更简洁,逻辑更简单;

 

三, 还有一些获取图片特征的VGG和Milvus组合使用:

参考:https://cloud.tencent.com/developer/article/1605032

 

参考:

1,https://maimai.cn/article/detail?fid=1743956531&efid=sTnHYzKAy8MK8AhgjSi7Bg

2,https://www.cnblogs.com/niulang/p/15921786.html

 

 

标签:name,搜图,image,time,collection,以图,图像,import
From: https://www.cnblogs.com/zwbsoft/p/16891539.html

相关文章

  • 数据是如何以图表形式呈现的?
    人都是视觉动物,凡事先入眼,再过脑,这是一个既定的流程,但“视觉效果”也就是可视化大屏相较于以往单一的数据表格最出彩的地方。 人的大脑和计算机一样分为长期记忆和短期......
  • js 以图片上的某个位置来缩放这个图片
    1"usestrict";23var__emptyPoint=null,__emptyContext=null,__emptyPointA=null;45constColorRefTable={6"aliceblue":......
  • 以图搜图
    ///<summary> ///感知哈希算法 ///</summary> publicclassImageComparer { ///<summary> ///获取图片的Hashcode ///</summary> ///<paramname="imageNam......
  • 生成二维码并以图片格式下载-qrcodejs2
    1、安装qrcodejs2npminstallqrcodejs2--save2、在需要的页面引入importQRCodefrom"qrcodejs2";3、页面中使用<divid="qrcode"ref="qrcode"></div>4......
  • php以图形方式显示中文,指定ttf字库
    1<?php2header("Content-Type:image/png");3$img=imagecreatetruecolor(400,300);4//imagejpeg($img);5//imagejpeg($img,"./img/copy_img01.jpg",10);......