Chinaunix首页 | 论坛 | 博客
  • 博客访问: 46279
  • 博文数量: 4
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 70
  • 用 户 组: 普通用户
  • 注册时间: 2024-05-13 18:19
文章分类
文章存档

2024年(4)

我的朋友

分类: 大数据

2024-05-24 11:36:21

本文主要介绍通过DashVector和ModelScope中的Chinese Clip模型实现文搜图、图搜图等功能,同时结合DashText SDK实现sparse vector+dense vector混合检索,熟悉sparse vector的使用方法,提高检索效率。

1 准备工作

1.1 基本概念

  1. Chinese Clip:为模型的中文版本,使用大规模中文数据进行训练(~2亿图文对),可用于图文检索和图像、文本的表征提取,应用于搜索、推荐等应用场景。详情请参考:
  2. DashVector:向量检索服务基于阿里云自研的向量引擎 Proxima 内核,提供具备水平拓展、全托管、云原生的高效向量检索服务。向量检索服务将强大的向量管理、查询等能力,通过简洁易用的 SDK/API 接口透出,方便在大模型知识库搭建、多模态 AI 搜索等多种应用场景上集成。详情请参考:
  3. MUGE数据集:MUGE(牧歌,Multimodal Understanding and Generation Evaluation)是业界{BANNED}首选大规模中文多模态评测基准,由达摩院联合浙江大学、阿里云天池平台联合发布,中国计算机学会计算机视觉专委会(CCF-CV专委)协助推出。目前包括: · 包含多模态理解与生成任务在内的多模态评测基准,其中包括图像描述、图文检索以及基于文本的图像生成。未来我们将公布更多任务及数据。 · 公开的评测榜单,帮助研究人员评估模型和追踪进展。 MUGE旨在推动多模态表示学习进展,尤其关注多模态预训练。具备多模态理解和生成能力的模型均可以参加此评测,欢迎各位与我们共同推动多模态领域发展。详情请参考:
  4. DashText是向量检索服务DashVector推荐使用的稀疏向量编码器(Sparse Vector Encoder),DashText可通过BM25算法将原始文本转换为稀疏向量(Sparse Vector)表达,通过DashText可大幅度简化使用DashVector关键词感知检索能力。详情请参考:https://help.aliyun.com/document_detail/2546039.html

1.2 准备工作

  1. 获取DashVector的API-KEY。API-KEY用于访问DashVector服务,详请参考:https://help.aliyun.com/document_detail/2510230.html

  2. 申请DashVector测试实例,DashVector提供免费试用实例,可以薅一波。详情请见:https://help.aliyun.com/document_detail/2568083.html
  3. 获取DashVector实例的endpoint,endpoint用于访问DashVector具体的实例。详情请见:https://help.aliyun.com/document_detail/2568084.html

  4. 安装DashVector、DashText、ModelScope的SDK

点击(此处)折叠或打开

  1. pip install dashvector
  2. pip install dashtext
  3. pip install modelscope
由于安装ModelScope SDK需要一些依赖,继续安装,安装的时间有点长,请耐心等待~~~~~

点击(此处)折叠或打开

  1. pip install decord
  2. pip install torch torchvision opencv-python timm librosa fairseq transformers unicodedata2 zhconv rapidfuzz
由于本教程中,会使用DashText的sdk生成稀疏向量,生成稀疏向量过程中会先下载一个词包,下载过程比较长。所以可以预先下载。

点击(此处)折叠或打开

  1. wget https://dashvector-data.oss-cn-beijing.aliyuncs.com/public/sparsevector/bm25_zh_default.json
好啦,SDK和依赖都安装完了,下面简单介绍一下多模态搜索的过程。

1.3 多模态搜索过程

  1. 多模态搜索分为两个过程,即索引过程和搜索过程。
  2. 索引过程:本教程在索引过程中,使用MUGE数据集,数据格式如下。只需要对MUGE数据集中的图片和文本提取特征,然后将特征插入到DashVector中,就完成了索引过程。
  1. [{
  2.     "query_id": "54372",
  3.     "query": "金属产品打印",
  4.     "image_id": "813904",
  5.     "image": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F8EB1F39DB0>
  6. },
  7. {
  8.     "query_id": "78633",
  9.     "query": "夹棉帽子",
  10.     "image_id": "749842",
  11.     "image": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F8EB0AFFF70>
  12. }]
  1. 搜索过程:通过对输入的文本或者图片,提取特征,并通过特征在DashVector中已经索引的向量中进行相似向量查询,并将查询后的结果解析成可视化的图片和文本,即完成了搜索过程。详情请看下图。

2 创建DashVector Collection

  1. from dashvector import Client
  2. # 如下填写您在1.2 准备工作中获取的DashVector API-KEY
  3. DASHVECTOR_API_KEY = '{YOUR DashVector API-KEY}'
  4. # 如下填写您在1.2 准备工作中获取的DashVector中Cluster中的Endpoint
  5. DASHVECTOR_END_POINT='{YOUR DashVector Endpoint}'
  6. # 初始化DashVector 的client
  7. client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_END_POINT)
  8. response = client.create(
  9.     # Collection的名称,名称可自定义。这里暂时定义为:ImageTextSearch
  10.     name='ImageTextSearch',
  11.     # 创建Collection的维度,注意一定是1024维。因为后面我们会使用Chinese Clip模型进行embedding,Chinese Clip模型的输出维度是1024维。
  12.     dimension=1024,
  13.     # 距离度量方式一定为dotproduct,因为稀疏向量只支持dotproduc这种度量方式。
  14.     metric='dotproduct',
  15.     dtype=float,
  16.     # 定义schema,通过schema可以定义Collection中包含哪些字段,以及字段的类型,以便实现更快速的搜索。这里定义了image_id、query和query_id三个schema。
  17.   # 关于Schema的详细使用请参考:https://help.aliyun.com/document_detail/2510228.html
  18.     fields_schema={'image_id': int, 'query': str, 'query_id': int}
  19. )
  20. print(response)

好啦,Collection创建成功了。

3 图片和文本索引

  1. 图片和文本插入,由于涉及到图片特征提取,所以速度会有点慢,建议使用GPU进行特征提取。
  1. # 首先import一大堆东西
  2. from modelscope.msdatasets import MsDataset
  3. from modelscope.utils.constant import Tasks
  4. from modelscope.pipelines import pipeline
  5. import base64
  6. import io
  7. from dashvector import Client, Doc, DashVectorCode, DashVectorException
  8. from dashtext import SparseVectorEncoder
  9. # load 数据集,选取modelscope中的muge数据集,取数据集中validation部分
  10. # muge数据集的格式为:
  11. # [{
  12. # "query_id": "54372",
  13. # "query": "金属产品打印",
  14. # "image_id": "813904",
  15. # "image": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F8EB1F39DB0>
  16. # },
  17. # {
  18. # "query_id": "78633",
  19. # "query": "夹棉帽子",
  20. # "image_id": "749842",
  21. # "image": <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7F8EB0AFFF70>
  22. # }]
  23. # 首次load muge数据集有点慢,请耐心等待。
  24. datasets = MsDataset.load("muge", split="validation")
  25. # 获取数据集的长度,也就是数据集中有多少个这样的数据
  26. datasets_len = len(datasets)
  27. # 初始化ModelScope推理pipeline,本教程使用Chinese Clip作为embedding模型。由于图片的Embedding比较消耗计算资源,所以推荐使用GPU进行推理加速。
  28. # 请参考:https://modelscope.cn/models/iic/multi-modal_clip-vit-huge-patch14_zh/summary
  29. pipeline = pipeline(task=Tasks.multi_modal_embedding,
  30.                     model='damo/multi-modal_clip-vit-huge-patch14_zh', model_revision='v1.0.0')
  31. # 初始化稀疏向量编码器,用于对muge数据集中的query进行稀疏向量编码,中文编码。详情请参考:https://help.aliyun.com/document_detail/2546039.html
  32. encoder = SparseVectorEncoder()
  33. # encoder初始化的时间比较长,主要原因在于稀疏向量编码器需要加载一个json文件,该文件比较大,需要下载。我们可以先下载完,保存在本地,直接加载,速度会快很多。
  34. # 下载链接:https://dashvector-data.oss-cn-beijing.aliyuncs.com/public/sparsevector/bm25_zh_default.json
  35. # 也可以使用:wget https://dashvector-data.oss-cn-beijing.aliyuncs.com/public/sparsevector/bm25_zh_default.json,直接下载到本地。
  36. # 下载完成后,放在本机目录中,本教程已经下载完成,放在根目录下。
  37. encoder_path = 'bm25_zh_default.json'
  38. encoder.load(encoder_path)
  39. # 如下填写您在1.2 准备工作中获取的DashVector API-KEY
  40. DASHVECTOR_API_KEY = '{YOUR DashVector API-KEY}'
  41. # 如下填写您在1.2 准备工作中获取的DashVector中Cluster中的Endpoint
  42. DASHVECTOR_END_POINT='{YOUR DashVector Endpoint}'
  43. # 初始化dashvector的Client,用于访问dashvector服务
  44. # 请参考:https://help.aliyun.com/document_detail/2510240.html
  45. client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_END_POINT)
  46. # 将图片转成字符串,用于将图片存储在dashvector中
  47. def image_to_str(image):
  48.     image_byte_arr = io.BytesIO()
  49.     image.save(image_byte_arr, format='PNG')
  50.     image_bytes = image_byte_arr.getvalue()
  51.     return base64.b64encode(image_bytes).decode()
  52. # 通过Chinese Clip提取图片特征,并转成向量
  53. def image_vector(image):
  54.     # 通过Chinese Clip提取图片特征,返回为一个tensor
  55.     img_embedding = pipeline.forward({'img': image})['img_embedding']
  56.     # 将返回的tensor转成向量,向量需要转存到cpu中
  57.     img_vector = img_embedding.detach().cpu().numpy()
  58.     return img_vector if isinstance(image, list) else img_vector[0]
  59. # 通过Chinese Clip提取文本特征,并转成向量
  60. def text_vector(text):
  61.     # 通过Chinese Clip提取文本特征,返回为一个tensor
  62.     text_embedding = pipeline.forward({'text': text})['text_embedding']
  63.     # 将返回的tensor转成向量,向量需要转存到cpu中
  64.     text_vector = text_embedding.detach().cpu().numpy()
  65.     return text_vector if isinstance(text, list) else text_vector[0]
  66. # 通过dashtext对文本生成稀疏向量。注意,本函数为生成入库的稀疏向量,而不是query的稀疏向量
  67. def sparse_vector_documents(text):
  68.     # 通过dashtext生成稀疏向量
  69.     sparse_vector = encoder.encode_documents(text)
  70.     return sparse_vector if isinstance(text, list) else sparse_vector
  71. # 插入向量数据,batch_size默认为10,{BANNED}{BANNED}最佳佳大不超过20
  72. def insert_docs(collection_name: str, partition='default', batch_size=10):
  73.     idx = 0
  74.     while idx < datasets_len:
  75.         # 获取batch range数据
  76.         batch_range = range(idx, idx + batch_size) if idx + batch_size < datasets_len else range(idx, datasets_len)
  77.         # 获取image信息
  78.         images = [datasets[i]['image'] for i in batch_range]
  79.         # 通过Chinese Clip提取图片特征,返回为一个vector
  80.         images_vector = image_vector(images)
  81.         # 获取query信息
  82.         texts = [datasets[i]['query'] for i in batch_range]
  83.         # 生成稀疏向量
  84.         documents_sparse_vector = sparse_vector_documents(texts)
  85.         # 获取图片ID和query ID
  86.         images_ids = [datasets[i]['image_id'] for i in batch_range]
  87.         query_ids = [datasets[i]['query_id'] for i in batch_range]
  88.         # 获取Collection
  89.         collection = client.get(collection_name)
  90.         # 批量插入
  91.         response = collection.upsert(
  92.             [
  93.                 Doc(
  94.                     id=image_id,
  95.                     vector=img_vector,
  96.                     sparse_vector=document_sparse_vector,
  97.                     fields={
  98.                         # 由于在创建Collection时,image_id和query_id都是int类型,所以这里需要转换为int类型
  99.                         'image_id': int(image_id),
  100.                         'query_id': int(query_id),
  101.                         'query': query,
  102.                         # 将Image格式转成字符串,用于存储在dashvector中
  103.                         'image': image_to_str(image)
  104.                     }
  105.                 ) for img_vector, document_sparse_vector, image_id, query_id, image, query in
  106.                 zip(images_vector, documents_sparse_vector, images_ids, query_ids, images, texts)
  107.             ]
  108.         )
  109.         print(response)
  110.         idx += batch_size
  111.     return response
  112. if __name__ == '__main__':
  113.     # 插入数据
  114.     response = insert_docs(collection_name='ImageTextSearch', batch_size=20)
  1. 向量插入后,就可以在DashVector控制台看到向量啦!



4 图片和文本搜索

  1. 图片插入成功后,即可进行图片和文本的跨模态搜索了,同样由于搜索过程中,涉及到图片特征提取,建议使用GPU进行。
  1. # 老规矩,先import一堆东西
  2. from modelscope.utils.constant import Tasks
  3. from modelscope.preprocessors.image import load_image
  4. from modelscope.pipelines import pipeline
  5. from PIL import Image
  6. import base64
  7. import io
  8. from dashvector import Client, Doc, DashVectorCode, DashVectorException
  9. from dashtext import SparseVectorEncoder, combine_dense_and_sparse
  10. from urllib.parse import urlparse
  11. # 初始化ModelScope推理pipeline,本教程使用Chinese Clip作为embedding模型。由于图片的Embedding比较消耗计算资源,所以推荐使用GPU进行推理加速。
  12. # 请参考:https://modelscope.cn/models/iic/multi-modal_clip-vit-huge-patch14_zh/summary
  13. pipeline = pipeline(task=Tasks.multi_modal_embedding,
  14.                     model='damo/multi-modal_clip-vit-huge-patch14_zh', model_revision='v1.0.0')
  15. # 初始化稀疏向量编码器,用于对muge数据集中的query进行稀疏向量编码,中文编码。详情请参考:https://help.aliyun.com/document_detail/2546039.html
  16. encoder = SparseVectorEncoder()
  17. # encoder初始化的时间比较长,主要原因在于稀疏向量编码器需要加载一个json文件,该文件比较大,需要下载。我们可以先下载完,保存在本地,直接加载,速度会快很多。
  18. # 下载链接:https://dashvector-data.oss-cn-beijing.aliyuncs.com/public/sparsevector/bm25_zh_default.json
  19. # 也可以使用:wget https://dashvector-data.oss-cn-beijing.aliyuncs.com/public/sparsevector/bm25_zh_default.json,直接下载到本地。
  20. # 下载完成后,放在本机目录中,本教程已经下载完成,放在根目录下。
  21. encoder_path = 'bm25_zh_default.json'
  22. encoder.load(encoder_path)
  23. # 如下填写您在1.2 准备工作中获取的DashVector API-KEY
  24. DASHVECTOR_API_KEY = '{YOUR DashVector API-KEY}'
  25. # 如下填写您在1.2 准备工作中获取的DashVector中Cluster中的Endpoint
  26. DASHVECTOR_END_POINT='{YOUR DashVector Endpoint}'
  27. # 初始化dashvector的Client,用于访问dashvector服务
  28. # 请参考:https://help.aliyun.com/document_detail/2510240.html
  29. client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_END_POINT)
  30. # 将字符串转为图片
  31. def str2image(image_str):
  32.     image_bytes = base64.b64decode(image_str)
  33.     return Image.open(io.BytesIO(image_bytes))
  34. # 判断是否为URL
  35. def is_url(url):
  36.     try:
  37.         result = urlparse(url)
  38.         return all([result.scheme, result.netloc])
  39.     except ValueError:
  40.         return False
  41. # 通过Chinese Clip提取图片特征,并转成向量
  42. def image_vector(image):
  43.     # 通过Chinese Clip提取图片特征,返回为一个tensor
  44.     img_embedding = pipeline.forward({'img': image})['img_embedding']
  45.     # 将返回的tensor转成向量,向量需要转存到cpu中
  46.     img_vector = img_embedding.detach().cpu().numpy()
  47.     return img_vector if isinstance(image, list) else img_vector[0]
  48. # 通过Chinese Clip提取文本特征,并转成向量
  49. def text_vector(text):
  50.     # 通过Chinese Clip提取文本特征,返回为一个tensor
  51.     text_embedding = pipeline.forward({'text': text})['text_embedding']
  52.     # 将返回的tensor转成向量,向量需要转存到cpu中
  53.     text_vector = text_embedding.detach().cpu().numpy()
  54.     return text_vector if isinstance(text, list) else text_vector[0]
  55. # 通过dashtext对文本生成稀疏向量。注意,本函数为query的稀疏向量,而不是入库的稀疏向量
  56. def sparse_vector_queries(text):
  57.     # 通过dashtext生成稀疏向量
  58.     sparse_vector = encoder.encode_queries(text)
  59.     return sparse_vector if isinstance(text, list) else sparse_vector
  60. # 通过文本和图片搜索图片,返回搜索结果。其中,文本会转换为稀疏向量,图片会转换成稠密向量,并通过alpha值控制稠密向量和稀疏向量的权重,alpha=1.0则全部使用稠密向量搜索,alpha=0.0则全部使用稀疏向量搜索
  61. def serach_by_imageAndtext(query_text, query_image, collection_name, partition='default', top_k=10, alpha=0.5):
  62.     if is_url(query_image):
  63.         query_image = load_image(query_image)
  64.     image_embedding = image_vector(query_image)
  65.     query_sparse_embedding = sparse_vector_queries(query_text)
  66.     scaled_dense_vector, scaled_sparse_vector = combine_dense_and_sparse(image_embedding, query_sparse_embedding, alpha)
  67.     try:
  68.         collection = client.get(name=collection_name)
  69.         # 搜索
  70.         docs = collection.query(
  71.             vector=scaled_dense_vector,
  72.             sparse_vector=scaled_sparse_vector,
  73.             partition=partition,
  74.             topk=top_k,
  75.             output_fields=['image', 'query', 'image_id']
  76.         )
  77.         image_list = list()
  78.         for doc in docs:
  79.             image_str = doc.fields['image']
  80.             # print(doc.score)
  81.             # print(doc.fields['query'])
  82.             # print(doc.fields['image_id'])
  83.             image_list.append(str2image(image_str))
  84.         return image_list
  85.     except DashVectorException as e:
  86.         print(e)
  87.         return []
  88. # 通过文本搜索图片,返回搜索结果,并将文本变成对应的稀疏向量和稠密向量,稀疏向量用来控制文本中是否包含该关键词,稠密向量用于控制图片中是否包含此信息。可通过alpha值控制稠密向量和稀疏向量的权重,alpha=1.0则全部使用稠密向量搜索,alpha=0.0则全部使用稀疏向量搜索
  89. def search_by_text(query_text, collection_name, partition='default', top_k=10, alpha=0.5):
  90.     query_embedding = text_vector(query_text)
  91.     print(query_embedding)
  92.     print(type(query_embedding))
  93.     print(query_embedding.dtype)
  94.     query_sparse_embedding = sparse_vector_queries(query_text)
  95.     scaled_dense_vector, scaled_sparse_vector = combine_dense_and_sparse(query_embedding, query_sparse_embedding, alpha)
  96.     try:
  97.         collection = client.get(name=collection_name)
  98.         # 搜索
  99.         docs = collection.query(
  100.             vector=scaled_dense_vector,
  101.             sparse_vector=scaled_sparse_vector,
  102.             partition=partition,
  103.             topk=top_k,
  104.             output_fields=['image', 'query', 'image_id']
  105.         )
  106.         image_list = list()
  107.         for doc in docs:
  108.             image_str = doc.fields['image']
  109.             # print(doc.score)
  110.             # print(doc.fields['query'])
  111.             # print(doc.fields['image_id'])
  112.             image_list.append(str2image(image_str))
  113.         return image_list
  114.     except DashVectorException as e:
  115.         print(e)
  116.         return []
  117. if __name__ == '__main__':
  118.     query_text = '女士帽子'
  119.     query_image = '!!2217497569457-0-cib.jpg?__r__=1711033209457'
  120.     # response = search_by_text(query_text=query_text, collection_name='ImageTextSearch', alpha=1.0)
  121.     response = serach_by_imageAndtext(query_text=query_text, query_image=query_image, collection_name='ImageTextSearch',
  122.                                       top_k=20, alpha=0.8)
  123.     for image in response:
  124.         image.show()
  1. 搜索结果出来啦!

阅读(385) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~